@@ -1014,16 +1014,16 @@ def test_mv_normal_moment(self, mu, cov, size, expected):
1014
1014
assert_moment_is_expected (model , expected , check_finite_logp = x .ndim < 3 )
1015
1015
1016
1016
@pytest .mark .parametrize (
1017
- "shape, zerosum_axes , expected" ,
1017
+ "shape, n_zerosum_axes , expected" ,
1018
1018
[
1019
1019
((2 , 5 ), None , np .zeros ((2 , 5 ))),
1020
1020
((2 , 5 , 6 ), 2 , np .zeros ((2 , 5 , 6 ))),
1021
1021
((2 , 5 , 6 ), 3 , np .zeros ((2 , 5 , 6 ))),
1022
1022
],
1023
1023
)
1024
- def test_zerosum_normal_moment (self , shape , zerosum_axes , expected ):
1024
+ def test_zerosum_normal_moment (self , shape , n_zerosum_axes , expected ):
1025
1025
with pm .Model () as model :
1026
- pm .ZeroSumNormal ("x" , shape = shape , zerosum_axes = zerosum_axes )
1026
+ pm .ZeroSumNormal ("x" , shape = shape , n_zerosum_axes = n_zerosum_axes )
1027
1027
assert_moment_is_expected (model , expected )
1028
1028
1029
1029
@pytest .mark .parametrize (
@@ -1405,16 +1405,16 @@ def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=
1405
1405
).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1406
1406
1407
1407
@pytest .mark .parametrize (
1408
- "dims, zerosum_axes " ,
1408
+ "dims, n_zerosum_axes " ,
1409
1409
[
1410
1410
(("regions" , "answers" ), None ),
1411
1411
(("regions" , "answers" ), 1 ),
1412
1412
(("regions" , "answers" ), 2 ),
1413
1413
],
1414
1414
)
1415
- def test_zsn_dims (self , dims , zerosum_axes ):
1415
+ def test_zsn_dims (self , dims , n_zerosum_axes ):
1416
1416
with pm .Model (coords = self .coords ) as m :
1417
- v = pm .ZeroSumNormal ("v" , dims = dims , zerosum_axes = zerosum_axes )
1417
+ v = pm .ZeroSumNormal ("v" , dims = dims , n_zerosum_axes = n_zerosum_axes )
1418
1418
s = pm .sample (10 , chains = 1 , tune = 100 )
1419
1419
1420
1420
# to test forward graph
@@ -1428,24 +1428,24 @@ def test_zsn_dims(self, dims, zerosum_axes):
1428
1428
)
1429
1429
1430
1430
ndim_supp = v .owner .op .ndim_supp
1431
- zerosum_axes = np .arange (- ndim_supp , 0 )
1431
+ n_zerosum_axes = np .arange (- ndim_supp , 0 )
1432
1432
nonzero_axes = np .arange (v .ndim - ndim_supp )
1433
1433
for samples in [
1434
1434
s .posterior .v ,
1435
1435
random_samples ,
1436
1436
]:
1437
- self .assert_zerosum_axes (samples , zerosum_axes )
1437
+ self .assert_zerosum_axes (samples , n_zerosum_axes )
1438
1438
self .assert_zerosum_axes (samples , nonzero_axes , check_zerosum_axes = False )
1439
1439
1440
1440
@pytest .mark .parametrize (
1441
- "zerosum_axes " ,
1441
+ "n_zerosum_axes " ,
1442
1442
(None , 1 , 2 ),
1443
1443
)
1444
- def test_zsn_shape (self , zerosum_axes ):
1444
+ def test_zsn_shape (self , n_zerosum_axes ):
1445
1445
shape = (len (self .coords ["regions" ]), len (self .coords ["answers" ]))
1446
1446
1447
1447
with pm .Model (coords = self .coords ) as m :
1448
- v = pm .ZeroSumNormal ("v" , shape = shape , zerosum_axes = zerosum_axes )
1448
+ v = pm .ZeroSumNormal ("v" , shape = shape , n_zerosum_axes = n_zerosum_axes )
1449
1449
s = pm .sample (10 , chains = 1 , tune = 100 )
1450
1450
1451
1451
# to test forward graph
@@ -1459,17 +1459,17 @@ def test_zsn_shape(self, zerosum_axes):
1459
1459
)
1460
1460
1461
1461
ndim_supp = v .owner .op .ndim_supp
1462
- zerosum_axes = np .arange (- ndim_supp , 0 )
1462
+ n_zerosum_axes = np .arange (- ndim_supp , 0 )
1463
1463
nonzero_axes = np .arange (v .ndim - ndim_supp )
1464
1464
for samples in [
1465
1465
s .posterior .v ,
1466
1466
random_samples ,
1467
1467
]:
1468
- self .assert_zerosum_axes (samples , zerosum_axes )
1468
+ self .assert_zerosum_axes (samples , n_zerosum_axes )
1469
1469
self .assert_zerosum_axes (samples , nonzero_axes , check_zerosum_axes = False )
1470
1470
1471
1471
@pytest .mark .parametrize (
1472
- "error, match, shape, support_shape, zerosum_axes " ,
1472
+ "error, match, shape, support_shape, n_zerosum_axes " ,
1473
1473
[
1474
1474
(
1475
1475
ValueError ,
@@ -1485,14 +1485,14 @@ def test_zsn_shape(self, zerosum_axes):
1485
1485
(3 , 4 ),
1486
1486
(3 , 4 ),
1487
1487
None ,
1488
- ), # doesn't work because zerosum_axes = 1 by default
1488
+ ), # doesn't work because n_zerosum_axes = 1 by default
1489
1489
],
1490
1490
)
1491
- def test_zsn_fail_axis (self , error , match , shape , support_shape , zerosum_axes ):
1491
+ def test_zsn_fail_axis (self , error , match , shape , support_shape , n_zerosum_axes ):
1492
1492
with pytest .raises (error , match = match ):
1493
1493
with pm .Model () as m :
1494
1494
_ = pm .ZeroSumNormal (
1495
- "v" , shape = shape , support_shape = support_shape , zerosum_axes = zerosum_axes
1495
+ "v" , shape = shape , support_shape = support_shape , n_zerosum_axes = n_zerosum_axes
1496
1496
)
1497
1497
1498
1498
@pytest .mark .parametrize (
@@ -1504,35 +1504,35 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
1504
1504
)
1505
1505
def test_zsn_support_shape (self , shape , support_shape ):
1506
1506
with pm .Model () as m :
1507
- v = pm .ZeroSumNormal ("v" , shape = shape , support_shape = support_shape , zerosum_axes = 2 )
1507
+ v = pm .ZeroSumNormal ("v" , shape = shape , support_shape = support_shape , n_zerosum_axes = 2 )
1508
1508
1509
1509
random_samples = pm .draw (v , draws = 10 )
1510
- zerosum_axes = np .arange (- 2 , 0 )
1511
- self .assert_zerosum_axes (random_samples , zerosum_axes )
1510
+ n_zerosum_axes = np .arange (- 2 , 0 )
1511
+ self .assert_zerosum_axes (random_samples , n_zerosum_axes )
1512
1512
1513
1513
@pytest .mark .parametrize (
1514
- "zerosum_axes " ,
1514
+ "n_zerosum_axes " ,
1515
1515
[1 , 2 ],
1516
1516
)
1517
- def test_zsn_change_dist_size (self , zerosum_axes ):
1518
- base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ), zerosum_axes = zerosum_axes )
1517
+ def test_zsn_change_dist_size (self , n_zerosum_axes ):
1518
+ base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ), n_zerosum_axes = n_zerosum_axes )
1519
1519
random_samples = pm .draw (base_dist , draws = 100 )
1520
1520
1521
- zerosum_axes = np .arange (- zerosum_axes , 0 )
1522
- self .assert_zerosum_axes (random_samples , zerosum_axes )
1521
+ n_zerosum_axes = np .arange (- n_zerosum_axes , 0 )
1522
+ self .assert_zerosum_axes (random_samples , n_zerosum_axes )
1523
1523
1524
1524
new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
1525
1525
try :
1526
1526
assert new_dist .eval ().shape == (5 , 3 , 9 )
1527
1527
except AssertionError :
1528
1528
assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
1529
1529
random_samples = pm .draw (new_dist , draws = 100 )
1530
- self .assert_zerosum_axes (random_samples , zerosum_axes )
1530
+ self .assert_zerosum_axes (random_samples , n_zerosum_axes )
1531
1531
1532
1532
new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = True )
1533
1533
assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
1534
1534
random_samples = pm .draw (new_dist , draws = 100 )
1535
- self .assert_zerosum_axes (random_samples , zerosum_axes )
1535
+ self .assert_zerosum_axes (random_samples , n_zerosum_axes )
1536
1536
1537
1537
@pytest .mark .parametrize (
1538
1538
"sigma, n" ,
@@ -1551,15 +1551,15 @@ def test_zsn_variance(self, sigma, n):
1551
1551
np .testing .assert_allclose (empirical_var , theoretical_var , atol = 0.4 )
1552
1552
1553
1553
@pytest .mark .parametrize (
1554
- "sigma, shape, zerosum_axes , mvn_axes" ,
1554
+ "sigma, shape, n_zerosum_axes , mvn_axes" ,
1555
1555
[
1556
1556
(5 , 3 , None , [- 1 ]),
1557
1557
(2 , 6 , None , [- 1 ]),
1558
1558
(5 , (7 , 3 ), None , [- 1 ]),
1559
1559
(5 , (2 , 7 , 3 ), 2 , [1 , 2 ]),
1560
1560
],
1561
1561
)
1562
- def test_zsn_logp (self , sigma , shape , zerosum_axes , mvn_axes ):
1562
+ def test_zsn_logp (self , sigma , shape , n_zerosum_axes , mvn_axes ):
1563
1563
def logp_norm (value , sigma , axes ):
1564
1564
"""
1565
1565
Special case of the MvNormal, that's equivalent to the ZSN.
@@ -1588,7 +1588,7 @@ def logp_norm(value, sigma, axes):
1588
1588
1589
1589
return np .where (inds , np .sum (- psdet - exp , axis = - 1 ), - np .inf )
1590
1590
1591
- zsn_dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = shape , zerosum_axes = zerosum_axes )
1591
+ zsn_dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = shape , n_zerosum_axes = n_zerosum_axes )
1592
1592
zsn_logp = pm .logp (zsn_dist , value = np .zeros (shape )).eval ()
1593
1593
mvn_logp = logp_norm (value = np .zeros (shape ), sigma = sigma , axes = mvn_axes )
1594
1594
0 commit comments