Skip to content

Commit e45e6c2

Browse files
michaelraczyckiMichal Raczycki
and
Michal Raczycki
authored
increase zerosum readability issue #6459 (#6522)
* added n_zerosum_axes and added backwards compatibility for the previous parameter name
* fixed typo that caused the zerosum_axes param not to be saved correctly in the new param
* adapted tests to use n_zerosum_axes
---------
Co-authored-by: Michal Raczycki <[email protected]>
1 parent ec30a2f commit e45e6c2

File tree

2 files changed

+82
-68
lines changed

2 files changed

+82
-68
lines changed

pymc/distributions/multivariate.py

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2394,7 +2394,7 @@ class ZeroSumNormal(Distribution):
23942394
ZeroSumNormal distribution, i.e. a Normal distribution where one or
23952395
several axes are constrained to sum to zero.
23962396
By default, the last axis is constrained to sum to zero.
2397-
See `zerosum_axes` kwarg for more details.
2397+
See `n_zerosum_axes` kwarg for more details.
23982398
23992399
.. math::
24002400
@@ -2411,9 +2411,10 @@ class ZeroSumNormal(Distribution):
24112411
It's actually the standard deviation of the underlying, unconstrained Normal distribution.
24122412
Defaults to 1 if not specified.
24132413
For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2414-
zerosum_axes: int, defaults to 1
2414+
n_zerosum_axes: int, defaults to 1
24152415
Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
24162416
Defaults to 1, i.e. the rightmost axis.
2417+
zerosum_axes: int, deprecated. Please use ``n_zerosum_axes`` instead.
24172418
dims: sequence of strings, optional
24182419
Dimension names of the distribution. Works the same as for other PyMC distributions.
24192420
Necessary if ``shape`` is not passed.
@@ -2426,13 +2427,13 @@ class ZeroSumNormal(Distribution):
24262427
``sigma`` has to be a scalar, to ensure the zero-sum constraint.
24272428
The ability to specify a vector of ``sigma`` may be added in future versions.
24282429
2429-
``zerosum_axes`` has to be > 0. If you want the behavior of ``zerosum_axes = 0``,
2430+
``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``,
24302431
just use ``pm.Normal``.
24312432
24322433
Examples
24332434
--------
24342435
Define a `ZeroSumNormal` variable, with `sigma=1` and
2435-
`zerosum_axes=1` by default::
2436+
`n_zerosum_axes=1` by default::
24362437
24372438
COORDS = {
24382439
"regions": ["a", "b", "c"],
@@ -2444,33 +2445,46 @@ class ZeroSumNormal(Distribution):
24442445
24452446
with pm.Model(coords=COORDS) as m:
24462447
# the zero sum axes will be 'answers' and 'regions'
2447-
v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)
2448+
v = pm.ZeroSumNormal("v", dims=("regions", "answers"), n_zerosum_axes=2)
24482449
24492450
with pm.Model(coords=COORDS) as m:
24502451
# the zero sum axes will be the last two
2451-
v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)
2452+
v = pm.ZeroSumNormal("v", shape=(3, 4, 5), n_zerosum_axes=2)
24522453
"""
24532454
rv_type = ZeroSumNormalRV
24542455

2455-
def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwargs):
2456+
def __new__(
2457+
cls, *args, zerosum_axes=None, n_zerosum_axes=None, support_shape=None, dims=None, **kwargs
2458+
):
2459+
if zerosum_axes is not None:
2460+
n_zerosum_axes = zerosum_axes
2461+
warnings.warn(
2462+
"The 'zerosum_axes' parameter is deprecated. Use 'n_zerosum_axes' instead.",
2463+
DeprecationWarning,
2464+
)
24562465
if dims is not None or kwargs.get("observed") is not None:
2457-
zerosum_axes = cls.check_zerosum_axes(zerosum_axes)
2466+
n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes)
24582467

24592468
support_shape = get_support_shape(
24602469
support_shape=support_shape,
24612470
shape=None, # Shape will be checked in `cls.dist`
24622471
dims=dims,
24632472
observed=kwargs.get("observed", None),
2464-
ndim_supp=zerosum_axes,
2473+
ndim_supp=n_zerosum_axes,
24652474
)
24662475

24672476
return super().__new__(
2468-
cls, *args, zerosum_axes=zerosum_axes, support_shape=support_shape, dims=dims, **kwargs
2477+
cls,
2478+
*args,
2479+
n_zerosum_axes=n_zerosum_axes,
2480+
support_shape=support_shape,
2481+
dims=dims,
2482+
**kwargs,
24692483
)
24702484

24712485
@classmethod
2472-
def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
2473-
zerosum_axes = cls.check_zerosum_axes(zerosum_axes)
2486+
def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
2487+
n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes)
24742488

24752489
sigma = at.as_tensor_variable(floatX(sigma))
24762490
if sigma.ndim > 0:
@@ -2479,41 +2493,41 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
24792493
support_shape = get_support_shape(
24802494
support_shape=support_shape,
24812495
shape=kwargs.get("shape"),
2482-
ndim_supp=zerosum_axes,
2496+
ndim_supp=n_zerosum_axes,
24832497
)
24842498

24852499
if support_shape is None:
2486-
if zerosum_axes > 0:
2500+
if n_zerosum_axes > 0:
24872501
raise ValueError("You must specify dims, shape or support_shape parameter")
24882502
# TODO: edge-case doesn't work for now, because at.stack in get_support_shape fails
24892503
# else:
24902504
# support_shape = () # because it's just a Normal in that case
24912505
support_shape = at.as_tensor_variable(intX(support_shape))
24922506

2493-
assert zerosum_axes == at.get_vector_length(
2507+
assert n_zerosum_axes == at.get_vector_length(
24942508
support_shape
2495-
), "support_shape has to be as long as zerosum_axes"
2509+
), "support_shape has to be as long as n_zerosum_axes"
24962510

24972511
return super().dist(
2498-
[sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs
2512+
[sigma], n_zerosum_axes=n_zerosum_axes, support_shape=support_shape, **kwargs
24992513
)
25002514

25012515
@classmethod
2502-
def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int:
2503-
if zerosum_axes is None:
2504-
zerosum_axes = 1
2505-
if not isinstance(zerosum_axes, int):
2506-
raise TypeError("zerosum_axes has to be an integer")
2507-
if not zerosum_axes > 0:
2508-
raise ValueError("zerosum_axes has to be > 0")
2509-
return zerosum_axes
2516+
def check_zerosum_axes(cls, n_zerosum_axes: Optional[int]) -> int:
2517+
if n_zerosum_axes is None:
2518+
n_zerosum_axes = 1
2519+
if not isinstance(n_zerosum_axes, int):
2520+
raise TypeError("n_zerosum_axes has to be an integer")
2521+
if not n_zerosum_axes > 0:
2522+
raise ValueError("n_zerosum_axes has to be > 0")
2523+
return n_zerosum_axes
25102524

25112525
@classmethod
2512-
def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
2526+
def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
25132527
shape = to_tuple(size) + tuple(support_shape)
25142528
normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))
25152529

2516-
if zerosum_axes > normal_dist.ndim:
2530+
if n_zerosum_axes > normal_dist.ndim:
25172531
raise ValueError("Shape of distribution is too small for the number of zerosum axes")
25182532

25192533
normal_dist_, sigma_, support_shape_ = (
@@ -2522,15 +2536,15 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
25222536
support_shape.type(),
25232537
)
25242538

2525-
# Zerosum-normaling is achieved by subtracting the mean along the given zerosum_axes
2539+
# The zero-sum constraint is achieved by subtracting the mean along each of the last n_zerosum_axes axes
25262540
zerosum_rv_ = normal_dist_
2527-
for axis in range(zerosum_axes):
2541+
for axis in range(n_zerosum_axes):
25282542
zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)
25292543

25302544
return ZeroSumNormalRV(
25312545
inputs=[normal_dist_, sigma_, support_shape_],
25322546
outputs=[zerosum_rv_, support_shape_],
2533-
ndim_supp=zerosum_axes,
2547+
ndim_supp=n_zerosum_axes,
25342548
)(normal_dist, sigma, support_shape)
25352549

25362550

@@ -2544,7 +2558,7 @@ def change_zerosum_size(op, normal_dist, new_size, expand=False):
25442558
new_size = tuple(new_size) + old_size
25452559

25462560
return ZeroSumNormal.rv_op(
2547-
sigma=sigma, zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size
2561+
sigma=sigma, n_zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size
25482562
)
25492563

25502564

@@ -2555,28 +2569,28 @@ def zerosumnormal_moment(op, rv, *rv_inputs):
25552569

25562570
@_default_transform.register(ZeroSumNormalRV)
25572571
def zerosum_default_transform(op, rv):
2558-
zerosum_axes = tuple(np.arange(-op.ndim_supp, 0))
2559-
return ZeroSumTransform(zerosum_axes)
2572+
n_zerosum_axes = tuple(np.arange(-op.ndim_supp, 0))
2573+
return ZeroSumTransform(n_zerosum_axes)
25602574

25612575

25622576
@_logprob.register(ZeroSumNormalRV)
25632577
def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
25642578
(value,) = values
25652579
shape = value.shape
2566-
zerosum_axes = op.ndim_supp
2580+
n_zerosum_axes = op.ndim_supp
25672581

2568-
_deg_free_support_shape = at.inc_subtensor(shape[-zerosum_axes:], -1)
2582+
_deg_free_support_shape = at.inc_subtensor(shape[-n_zerosum_axes:], -1)
25692583
_full_size = at.prod(shape)
25702584
_degrees_of_freedom = at.prod(_deg_free_support_shape)
25712585

25722586
zerosums = [
25732587
at.all(at.isclose(at.mean(value, axis=-axis - 1), 0, atol=1e-9))
2574-
for axis in range(zerosum_axes)
2588+
for axis in range(n_zerosum_axes)
25752589
]
25762590

25772591
out = at.sum(
25782592
pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size,
2579-
axis=tuple(np.arange(-zerosum_axes, 0)),
2593+
axis=tuple(np.arange(-n_zerosum_axes, 0)),
25802594
)
25812595

2582-
return check_parameters(out, *zerosums, msg="mean(value, axis=zerosum_axes) = 0")
2596+
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")

pymc/tests/distributions/test_multivariate.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,16 +1014,16 @@ def test_mv_normal_moment(self, mu, cov, size, expected):
10141014
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
10151015

10161016
@pytest.mark.parametrize(
1017-
"shape, zerosum_axes, expected",
1017+
"shape, n_zerosum_axes, expected",
10181018
[
10191019
((2, 5), None, np.zeros((2, 5))),
10201020
((2, 5, 6), 2, np.zeros((2, 5, 6))),
10211021
((2, 5, 6), 3, np.zeros((2, 5, 6))),
10221022
],
10231023
)
1024-
def test_zerosum_normal_moment(self, shape, zerosum_axes, expected):
1024+
def test_zerosum_normal_moment(self, shape, n_zerosum_axes, expected):
10251025
with pm.Model() as model:
1026-
pm.ZeroSumNormal("x", shape=shape, zerosum_axes=zerosum_axes)
1026+
pm.ZeroSumNormal("x", shape=shape, n_zerosum_axes=n_zerosum_axes)
10271027
assert_moment_is_expected(model, expected)
10281028

10291029
@pytest.mark.parametrize(
@@ -1405,16 +1405,16 @@ def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=
14051405
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
14061406

14071407
@pytest.mark.parametrize(
1408-
"dims, zerosum_axes",
1408+
"dims, n_zerosum_axes",
14091409
[
14101410
(("regions", "answers"), None),
14111411
(("regions", "answers"), 1),
14121412
(("regions", "answers"), 2),
14131413
],
14141414
)
1415-
def test_zsn_dims(self, dims, zerosum_axes):
1415+
def test_zsn_dims(self, dims, n_zerosum_axes):
14161416
with pm.Model(coords=self.coords) as m:
1417-
v = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
1417+
v = pm.ZeroSumNormal("v", dims=dims, n_zerosum_axes=n_zerosum_axes)
14181418
s = pm.sample(10, chains=1, tune=100)
14191419

14201420
# to test forward graph
@@ -1428,24 +1428,24 @@ def test_zsn_dims(self, dims, zerosum_axes):
14281428
)
14291429

14301430
ndim_supp = v.owner.op.ndim_supp
1431-
zerosum_axes = np.arange(-ndim_supp, 0)
1431+
n_zerosum_axes = np.arange(-ndim_supp, 0)
14321432
nonzero_axes = np.arange(v.ndim - ndim_supp)
14331433
for samples in [
14341434
s.posterior.v,
14351435
random_samples,
14361436
]:
1437-
self.assert_zerosum_axes(samples, zerosum_axes)
1437+
self.assert_zerosum_axes(samples, n_zerosum_axes)
14381438
self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False)
14391439

14401440
@pytest.mark.parametrize(
1441-
"zerosum_axes",
1441+
"n_zerosum_axes",
14421442
(None, 1, 2),
14431443
)
1444-
def test_zsn_shape(self, zerosum_axes):
1444+
def test_zsn_shape(self, n_zerosum_axes):
14451445
shape = (len(self.coords["regions"]), len(self.coords["answers"]))
14461446

14471447
with pm.Model(coords=self.coords) as m:
1448-
v = pm.ZeroSumNormal("v", shape=shape, zerosum_axes=zerosum_axes)
1448+
v = pm.ZeroSumNormal("v", shape=shape, n_zerosum_axes=n_zerosum_axes)
14491449
s = pm.sample(10, chains=1, tune=100)
14501450

14511451
# to test forward graph
@@ -1459,17 +1459,17 @@ def test_zsn_shape(self, zerosum_axes):
14591459
)
14601460

14611461
ndim_supp = v.owner.op.ndim_supp
1462-
zerosum_axes = np.arange(-ndim_supp, 0)
1462+
n_zerosum_axes = np.arange(-ndim_supp, 0)
14631463
nonzero_axes = np.arange(v.ndim - ndim_supp)
14641464
for samples in [
14651465
s.posterior.v,
14661466
random_samples,
14671467
]:
1468-
self.assert_zerosum_axes(samples, zerosum_axes)
1468+
self.assert_zerosum_axes(samples, n_zerosum_axes)
14691469
self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False)
14701470

14711471
@pytest.mark.parametrize(
1472-
"error, match, shape, support_shape, zerosum_axes",
1472+
"error, match, shape, support_shape, n_zerosum_axes",
14731473
[
14741474
(
14751475
ValueError,
@@ -1485,14 +1485,14 @@ def test_zsn_shape(self, zerosum_axes):
14851485
(3, 4),
14861486
(3, 4),
14871487
None,
1488-
), # doesn't work because zerosum_axes = 1 by default
1488+
), # doesn't work because n_zerosum_axes = 1 by default
14891489
],
14901490
)
1491-
def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
1491+
def test_zsn_fail_axis(self, error, match, shape, support_shape, n_zerosum_axes):
14921492
with pytest.raises(error, match=match):
14931493
with pm.Model() as m:
14941494
_ = pm.ZeroSumNormal(
1495-
"v", shape=shape, support_shape=support_shape, zerosum_axes=zerosum_axes
1495+
"v", shape=shape, support_shape=support_shape, n_zerosum_axes=n_zerosum_axes
14961496
)
14971497

14981498
@pytest.mark.parametrize(
@@ -1504,35 +1504,35 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
15041504
)
15051505
def test_zsn_support_shape(self, shape, support_shape):
15061506
with pm.Model() as m:
1507-
v = pm.ZeroSumNormal("v", shape=shape, support_shape=support_shape, zerosum_axes=2)
1507+
v = pm.ZeroSumNormal("v", shape=shape, support_shape=support_shape, n_zerosum_axes=2)
15081508

15091509
random_samples = pm.draw(v, draws=10)
1510-
zerosum_axes = np.arange(-2, 0)
1511-
self.assert_zerosum_axes(random_samples, zerosum_axes)
1510+
n_zerosum_axes = np.arange(-2, 0)
1511+
self.assert_zerosum_axes(random_samples, n_zerosum_axes)
15121512

15131513
@pytest.mark.parametrize(
1514-
"zerosum_axes",
1514+
"n_zerosum_axes",
15151515
[1, 2],
15161516
)
1517-
def test_zsn_change_dist_size(self, zerosum_axes):
1518-
base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), zerosum_axes=zerosum_axes)
1517+
def test_zsn_change_dist_size(self, n_zerosum_axes):
1518+
base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), n_zerosum_axes=n_zerosum_axes)
15191519
random_samples = pm.draw(base_dist, draws=100)
15201520

1521-
zerosum_axes = np.arange(-zerosum_axes, 0)
1522-
self.assert_zerosum_axes(random_samples, zerosum_axes)
1521+
n_zerosum_axes = np.arange(-n_zerosum_axes, 0)
1522+
self.assert_zerosum_axes(random_samples, n_zerosum_axes)
15231523

15241524
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False)
15251525
try:
15261526
assert new_dist.eval().shape == (5, 3, 9)
15271527
except AssertionError:
15281528
assert new_dist.eval().shape == (5, 3, 4, 9)
15291529
random_samples = pm.draw(new_dist, draws=100)
1530-
self.assert_zerosum_axes(random_samples, zerosum_axes)
1530+
self.assert_zerosum_axes(random_samples, n_zerosum_axes)
15311531

15321532
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True)
15331533
assert new_dist.eval().shape == (5, 3, 4, 9)
15341534
random_samples = pm.draw(new_dist, draws=100)
1535-
self.assert_zerosum_axes(random_samples, zerosum_axes)
1535+
self.assert_zerosum_axes(random_samples, n_zerosum_axes)
15361536

15371537
@pytest.mark.parametrize(
15381538
"sigma, n",
@@ -1551,15 +1551,15 @@ def test_zsn_variance(self, sigma, n):
15511551
np.testing.assert_allclose(empirical_var, theoretical_var, atol=0.4)
15521552

15531553
@pytest.mark.parametrize(
1554-
"sigma, shape, zerosum_axes, mvn_axes",
1554+
"sigma, shape, n_zerosum_axes, mvn_axes",
15551555
[
15561556
(5, 3, None, [-1]),
15571557
(2, 6, None, [-1]),
15581558
(5, (7, 3), None, [-1]),
15591559
(5, (2, 7, 3), 2, [1, 2]),
15601560
],
15611561
)
1562-
def test_zsn_logp(self, sigma, shape, zerosum_axes, mvn_axes):
1562+
def test_zsn_logp(self, sigma, shape, n_zerosum_axes, mvn_axes):
15631563
def logp_norm(value, sigma, axes):
15641564
"""
15651565
Special case of the MvNormal, that's equivalent to the ZSN.
@@ -1588,7 +1588,7 @@ def logp_norm(value, sigma, axes):
15881588

15891589
return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf)
15901590

1591-
zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, zerosum_axes=zerosum_axes)
1591+
zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, n_zerosum_axes=n_zerosum_axes)
15921592
zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval()
15931593
mvn_logp = logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes)
15941594

0 commit comments

Comments
 (0)