36
36
37
37
import pymc3 as pm
38
38
39
- from pymc3 .aesaraf import change_rv_size , floatX , intX
39
+ from pymc3 .aesaraf import floatX , intX
40
40
from pymc3 .distributions import (
41
41
AR1 ,
42
42
CAR ,
102
102
continuous ,
103
103
logcdf ,
104
104
logp ,
105
+ logpt ,
105
106
logpt_sum ,
106
107
)
107
- from pymc3 .math import kronecker , logsumexp
108
+ from pymc3 .math import kronecker
108
109
from pymc3 .model import Deterministic , Model , Point
109
110
from pymc3 .tests .helpers import select_by_precision
110
111
from pymc3 .vartypes import continuous_types
@@ -751,25 +752,33 @@ def check_logcdf(
751
752
if not skip_paramdomain_inside_edge_test :
752
753
domains = paramdomains .copy ()
753
754
domains ["value" ] = domain
755
+
756
+ model , param_vars = build_model (pymc3_dist , domain , paramdomains )
757
+ pymc3_logcdf = model .fastfn (logpt (model ["value" ], cdf = True ))
758
+
754
759
if decimal is None :
755
760
decimal = select_by_precision (float64 = 6 , float32 = 3 )
756
761
757
762
for pt in product (domains , n_samples = n_samples ):
758
763
params = dict (pt )
759
764
if skip_params_fn (params ):
760
765
continue
761
- scipy_cdf = scipy_logcdf (** params )
766
+
767
+ scipy_eval = scipy_logcdf (** params )
768
+
762
769
value = params .pop ("value" )
763
- with Model () as m :
764
- dist = pymc3_dist ("y" , ** params )
770
+ # Update shared parameter variables in pymc3_logcdf function
771
+ for param_name , param_value in params .items ():
772
+ param_vars [param_name ].set_value (param_value )
773
+ pymc3_eval = pymc3_logcdf ({"value" : value })
774
+
765
775
params ["value" ] = value # for displaying in err_msg
766
- with aesara .config .change_flags (on_opt_error = "raise" , mode = Mode ("py" )):
767
- assert_almost_equal (
768
- logcdf (dist , value ).eval (),
769
- scipy_cdf ,
770
- decimal = decimal ,
771
- err_msg = str (params ),
772
- )
776
+ assert_almost_equal (
777
+ pymc3_eval ,
778
+ scipy_eval ,
779
+ decimal = decimal ,
780
+ err_msg = str (params ),
781
+ )
773
782
774
783
valid_value = domain .vals [0 ]
775
784
valid_params = {param : paramdomain .vals [0 ] for param , paramdomain in paramdomains .items ()}
@@ -849,24 +858,33 @@ def check_selfconsistency_discrete_logcdf(
849
858
"""
850
859
Check that logcdf of discrete distributions matches sum of logps up to value
851
860
"""
861
+ # This test only works for scalar random variables
862
+ assert distribution .rv_op .ndim_supp == 0
863
+
852
864
domains = paramdomains .copy ()
853
865
domains ["value" ] = domain
854
866
if decimal is None :
855
867
decimal = select_by_precision (float64 = 6 , float32 = 3 )
868
+
869
+ model , param_vars = build_model (distribution , domain , paramdomains )
870
+ dist_logcdf = model .fastfn (logpt (model ["value" ], cdf = True ))
871
+ dist_logp = model .fastfn (logpt (model ["value" ]))
872
+
856
873
for pt in product (domains , n_samples = n_samples ):
857
874
params = dict (pt )
858
875
if skip_params_fn (params ):
859
876
continue
860
877
value = params .pop ("value" )
861
878
values = np .arange (domain .lower , value + 1 )
862
- dist = distribution .dist (** params )
863
- # This only works for scalar random variables
864
- assert dist .owner .op .ndim_supp == 0
865
- values_dist = change_rv_size (dist , values .shape )
879
+
880
+ # Update shared parameter variables in logp/logcdf function
881
+ for param_name , param_value in params .items ():
882
+ param_vars [param_name ].set_value (param_value )
883
+
866
884
with aesara .config .change_flags (mode = Mode ("py" )):
867
885
assert_almost_equal (
868
- logcdf ( dist , value ). eval ( ),
869
- logsumexp (logp ( values_dist , values ), keepdims = False ). eval ( ),
886
+ dist_logcdf ({ "value" : value } ),
887
+ scipy . special . logsumexp ([ dist_logp ({ "value" : value }) for value in values ] ),
870
888
decimal = decimal ,
871
889
err_msg = str (pt ),
872
890
)
@@ -1140,13 +1158,17 @@ def test_beta(self):
1140
1158
{"alpha" : Rplus , "beta" : Rplus },
1141
1159
lambda value , alpha , beta : sp .beta .logpdf (value , alpha , beta ),
1142
1160
)
1143
- self .check_logp (Beta , Unit , {"mu" : Unit , "sigma" : Rplus }, beta_mu_sigma )
1161
+ self .check_logp (
1162
+ Beta ,
1163
+ Unit ,
1164
+ {"mu" : Unit , "sigma" : Rplus },
1165
+ beta_mu_sigma ,
1166
+ )
1144
1167
self .check_logcdf (
1145
1168
Beta ,
1146
1169
Unit ,
1147
1170
{"alpha" : Rplus , "beta" : Rplus },
1148
1171
lambda value , alpha , beta : sp .beta .logcdf (value , alpha , beta ),
1149
- n_samples = 10 ,
1150
1172
decimal = select_by_precision (float64 = 5 , float32 = 3 ),
1151
1173
)
1152
1174
@@ -1269,20 +1291,17 @@ def scipy_mu_alpha_logcdf(value, mu, alpha):
1269
1291
Nat ,
1270
1292
{"mu" : Rplus , "alpha" : Rplus },
1271
1293
scipy_mu_alpha_logcdf ,
1272
- n_samples = 5 ,
1273
1294
)
1274
1295
self .check_logcdf (
1275
1296
NegativeBinomial ,
1276
1297
Nat ,
1277
1298
{"p" : Unit , "n" : Rplus },
1278
1299
lambda value , p , n : sp .nbinom .logcdf (value , n , p ),
1279
- n_samples = 5 ,
1280
1300
)
1281
1301
self .check_selfconsistency_discrete_logcdf (
1282
1302
NegativeBinomial ,
1283
1303
Nat ,
1284
1304
{"mu" : Rplus , "alpha" : Rplus },
1285
- n_samples = 10 ,
1286
1305
)
1287
1306
1288
1307
@pytest .mark .parametrize (
@@ -1340,7 +1359,6 @@ def test_lognormal(self):
1340
1359
Rplus ,
1341
1360
{"mu" : R , "sigma" : Rplusbig },
1342
1361
lambda value , mu , sigma : floatX (sp .lognorm .logpdf (value , sigma , 0 , np .exp (mu ))),
1343
- n_samples = 5 , # Just testing alternative parametrization
1344
1362
)
1345
1363
self .check_logcdf (
1346
1364
Lognormal ,
@@ -1353,7 +1371,6 @@ def test_lognormal(self):
1353
1371
Rplus ,
1354
1372
{"mu" : R , "sigma" : Rplusbig },
1355
1373
lambda value , mu , sigma : sp .lognorm .logcdf (value , sigma , 0 , np .exp (mu )),
1356
- n_samples = 5 , # Just testing alternative parametrization
1357
1374
)
1358
1375
1359
1376
def test_t (self ):
@@ -1368,14 +1385,12 @@ def test_t(self):
1368
1385
R ,
1369
1386
{"nu" : Rplus , "mu" : R , "sigma" : Rplus },
1370
1387
lambda value , nu , mu , sigma : sp .t .logpdf (value , nu , mu , sigma ),
1371
- n_samples = 5 , # Just testing alternative parametrization
1372
1388
)
1373
1389
self .check_logcdf (
1374
1390
StudentT ,
1375
1391
R ,
1376
1392
{"nu" : Rplus , "mu" : R , "lam" : Rplus },
1377
1393
lambda value , nu , mu , lam : sp .t .logcdf (value , nu , mu , lam ** - 0.5 ),
1378
- n_samples = 10 , # relies on slow incomplete beta
1379
1394
)
1380
1395
# TODO: reenable when PR #4736 is merged
1381
1396
"""
@@ -1384,7 +1399,6 @@ def test_t(self):
1384
1399
R,
1385
1400
{"nu": Rplus, "mu": R, "sigma": Rplus},
1386
1401
lambda value, nu, mu, sigma: sp.t.logcdf(value, nu, mu, sigma),
1387
- n_samples=5, # Just testing alternative parametrization
1388
1402
)
1389
1403
"""
1390
1404
@@ -1561,13 +1575,11 @@ def test_binomial(self):
1561
1575
Nat ,
1562
1576
{"n" : NatSmall , "p" : Unit },
1563
1577
lambda value , n , p : sp .binom .logcdf (value , n , p ),
1564
- n_samples = 10 ,
1565
1578
)
1566
1579
self .check_selfconsistency_discrete_logcdf (
1567
1580
Binomial ,
1568
1581
Nat ,
1569
1582
{"n" : NatSmall , "p" : Unit },
1570
- n_samples = 10 ,
1571
1583
)
1572
1584
1573
1585
@pytest .mark .xfail (reason = "checkd tests has not been refactored" )
@@ -1769,14 +1781,12 @@ def logcdf_fn(value, psi, mu, alpha):
1769
1781
Nat ,
1770
1782
{"psi" : Unit , "mu" : Rplusbig , "alpha" : Rplusbig },
1771
1783
logcdf_fn ,
1772
- n_samples = 10 ,
1773
1784
)
1774
1785
1775
1786
self .check_selfconsistency_discrete_logcdf (
1776
1787
ZeroInflatedNegativeBinomial ,
1777
1788
Nat ,
1778
1789
{"psi" : Unit , "mu" : Rplusbig , "alpha" : Rplusbig },
1779
- n_samples = 10 ,
1780
1790
)
1781
1791
1782
1792
@pytest .mark .xfail (reason = "Test not refactored yet" )
@@ -1809,14 +1819,12 @@ def logcdf_fn(value, psi, n, p):
1809
1819
Nat ,
1810
1820
{"psi" : Unit , "n" : NatSmall , "p" : Unit },
1811
1821
logcdf_fn ,
1812
- n_samples = 10 ,
1813
1822
)
1814
1823
1815
1824
self .check_selfconsistency_discrete_logcdf (
1816
1825
ZeroInflatedBinomial ,
1817
1826
Nat ,
1818
1827
{"n" : NatSmall , "p" : Unit , "psi" : Unit },
1819
- n_samples = 10 ,
1820
1828
)
1821
1829
1822
1830
@pytest .mark .parametrize ("n" , [1 , 2 , 3 ])
0 commit comments