24
24
import pytest
25
25
import scipy .stats as st
26
26
27
+ from numpy .testing import assert_almost_equal
27
28
from scipy import linalg
28
29
from scipy .special import expit
29
30
30
31
import pymc3 as pm
31
32
32
33
from pymc3 .aesaraf import change_rv_size , floatX , intX
33
- from pymc3 .distributions .dist_math import clipped_beta_rvs
34
34
from pymc3 .distributions .shape_utils import to_tuple
35
35
from pymc3 .exceptions import ShapeError
36
36
from pymc3 .tests .helpers import SeededTest , select_by_precision
@@ -524,6 +524,76 @@ def test_dirichlet_random_shape(self, shape, size):
524
524
assert pm .Dirichlet .dist (a = np .ones (shape )).random (size = size ).shape == out_shape
525
525
526
526
527
+ class TestCorrectParametrizationMappingPymcToScipy (SeededTest ):
528
+ @staticmethod
529
+ def get_inputs_from_apply_node_outputs (outputs ):
530
+ parents = outputs .get_parents ()
531
+ if not parents :
532
+ raise Exception ("Parent Apply node missing for output" )
533
+ # I am assuming there will always only be 1 Apply parent node in this context
534
+ return parents [0 ].inputs
535
+
536
+ def test_pymc_params_match_rv_ones (
537
+ self , pymc_params , expected_aesara_params , pymc_dist , decimal = 6
538
+ ):
539
+ pymc_dist_output = pymc_dist .dist (** dict (pymc_params ))
540
+ aesera_dist_inputs = self .get_inputs_from_apply_node_outputs (pymc_dist_output )[3 :]
541
+ assert len (expected_aesara_params ) == len (aesera_dist_inputs )
542
+ for (expected_name , expected_value ), actual_variable in zip (
543
+ expected_aesara_params , aesera_dist_inputs
544
+ ):
545
+ assert_almost_equal (expected_value , actual_variable .eval (), decimal = decimal )
546
+
547
+ def test_normal (self ):
548
+ params = [("mu" , 5.0 ), ("sigma" , 10.0 )]
549
+ self .test_pymc_params_match_rv_ones (params , params , pm .Normal )
550
+
551
+ def test_uniform (self ):
552
+ params = [("lower" , 0.5 ), ("upper" , 1.5 )]
553
+ self .test_pymc_params_match_rv_ones (params , params , pm .Uniform )
554
+
555
+ def test_half_normal (self ):
556
+ params , expected_aesara_params = [("sigma" , 10.0 )], [("mean" , 0 ), ("sigma" , 10.0 )]
557
+ self .test_pymc_params_match_rv_ones (params , expected_aesara_params , pm .HalfNormal )
558
+
559
+ def test_beta_alpha_beta (self ):
560
+ params = [("alpha" , 2.0 ), ("beta" , 5.0 )]
561
+ self .test_pymc_params_match_rv_ones (params , params , pm .Beta )
562
+
563
+ def test_beta_mu_sigma (self ):
564
+ params = [("mu" , 2.0 ), ("sigma" , 5.0 )]
565
+ expected_alpha , expected_beta = pm .Beta .get_alpha_beta (mu = params [0 ][1 ], sigma = params [1 ][1 ])
566
+ expected_params = [("alpha" , expected_alpha ), ("beta" , expected_beta )]
567
+ self .test_pymc_params_match_rv_ones (params , expected_params , pm .Beta )
568
+
569
+ @pytest .mark .skip (reason = "Expected to fail due to bug" )
570
+ def test_exponential (self ):
571
+ params = [("lam" , 10.0 )]
572
+ expected_params = [("lam" , 1 / params [0 ][1 ])]
573
+ self .test_pymc_params_match_rv_ones (params , expected_params , pm .Exponential )
574
+
575
+ def test_cauchy (self ):
576
+ params = [("alpha" , 2.0 ), ("beta" , 5.0 )]
577
+ self .test_pymc_params_match_rv_ones (params , params , pm .Cauchy )
578
+
579
+ def test_half_cauchy (self ):
580
+ params = [("alpha" , 2.0 ), ("beta" , 5.0 )]
581
+ self .test_pymc_params_match_rv_ones (params , params , pm .HalfCauchy )
582
+
583
+ @pytest .mark .skip (reason = "Expected to fail due to bug" )
584
+ def test_gamma_alpha_beta (self ):
585
+ params = [("alpha" , 2.0 ), ("beta" , 5.0 )]
586
+ expected_params = [("alpha" , params [0 ][1 ]), ("beta" , 1 / params [1 ][1 ])]
587
+ self .test_pymc_params_match_rv_ones (params , expected_params , pm .Gamma )
588
+
589
+ @pytest .mark .skip (reason = "Expected to fail due to bug" )
590
+ def test_gamma_mu_sigma (self ):
591
+ params = [("mu" , 2.0 ), ("sigma" , 5.0 )]
592
+ expected_alpha , expected_beta = pm .Gamma .get_alpha_beta (mu = params [0 ][1 ], sigma = params [1 ][1 ])
593
+ expected_params = [("alpha" , expected_alpha ), ("beta" , 1 / expected_beta )]
594
+ self .test_pymc_params_match_rv_ones (params , expected_params , pm .Gamma )
595
+
596
+
527
597
class TestScalarParameterSamples (SeededTest ):
528
598
@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
529
599
def test_bounded (self ):
@@ -535,20 +605,6 @@ def ref_rand(size, tau):
535
605
536
606
pymc3_random (BoundedNormal , {"tau" : Rplus }, ref_rand = ref_rand )
537
607
538
- @pytest .mark .skip (reason = "This test is covered by Aesara" )
539
- def test_uniform (self ):
540
- def ref_rand (size , lower , upper ):
541
- return st .uniform .rvs (size = size , loc = lower , scale = upper - lower )
542
-
543
- pymc3_random (pm .Uniform , {"lower" : - Rplus , "upper" : Rplus }, ref_rand = ref_rand )
544
-
545
- @pytest .mark .skip (reason = "This test is covered by Aesara" )
546
- def test_normal (self ):
547
- def ref_rand (size , mu , sigma ):
548
- return st .norm .rvs (size = size , loc = mu , scale = sigma )
549
-
550
- pymc3_random (pm .Normal , {"mu" : R , "sigma" : Rplus }, ref_rand = ref_rand )
551
-
552
608
@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
553
609
def test_truncated_normal (self ):
554
610
def ref_rand (size , mu , sigma , lower , upper ):
@@ -587,13 +643,6 @@ def ref_rand(size, alpha, mu, sigma):
587
643
588
644
pymc3_random (pm .SkewNormal , {"mu" : R , "sigma" : Rplus , "alpha" : R }, ref_rand = ref_rand )
589
645
590
- @pytest .mark .skip (reason = "This test is covered by Aesara" )
591
- def test_half_normal (self ):
592
- def ref_rand (size , tau ):
593
- return st .halfnorm .rvs (size = size , loc = 0 , scale = tau ** - 0.5 )
594
-
595
- pymc3_random (pm .HalfNormal , {"tau" : Rplus }, ref_rand = ref_rand )
596
-
597
646
@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
598
647
def test_wald (self ):
599
648
# Cannot do anything too exciting as scipy wald is a
@@ -607,13 +656,6 @@ def ref_rand(size, mu, lam, alpha):
607
656
ref_rand = ref_rand ,
608
657
)
609
658
610
- @pytest .mark .skip (reason = "This test is covered by Aesara" )
611
- def test_beta (self ):
612
- def ref_rand (size , alpha , beta ):
613
- return clipped_beta_rvs (a = alpha , b = beta , size = size )
614
-
615
- pymc3_random (pm .Beta , {"alpha" : Rplus , "beta" : Rplus }, ref_rand = ref_rand )
616
-
617
659
@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
618
660
def test_laplace (self ):
619
661
def ref_rand (size , mu , b ):
@@ -648,20 +690,7 @@ def ref_rand(size, nu, mu, lam):
648
690
pymc3_random (pm .StudentT , {"nu" : Rplus , "mu" : R , "lam" : Rplus }, ref_rand = ref_rand )
649
691
650
692
@pytest .mark .skip (reason = "This test is covered by Aesara" )
651
- def test_cauchy (self ):
652
- def ref_rand (size , alpha , beta ):
653
- return st .cauchy .rvs (alpha , beta , size = size )
654
-
655
- pymc3_random (pm .Cauchy , {"alpha" : R , "beta" : Rplusbig }, ref_rand = ref_rand )
656
693
657
- @pytest .mark .skip (reason = "This test is covered by Aesara" )
658
- def test_half_cauchy (self ):
659
- def ref_rand (size , beta ):
660
- return st .halfcauchy .rvs (scale = beta , size = size )
661
-
662
- pymc3_random (pm .HalfCauchy , {"beta" : Rplusbig }, ref_rand = ref_rand )
663
-
664
- @pytest .mark .skip (reason = "This test is covered by Aesara" )
665
694
def test_inverse_gamma (self ):
666
695
def ref_rand (size , alpha , beta ):
667
696
return st .invgamma .rvs (a = alpha , scale = beta , size = size )
0 commit comments