32
32
33
33
from pymc3 .aesaraf import change_rv_size , floatX , intX
34
34
from pymc3 .distributions .continuous import get_tau_sigma
35
+ from pymc3 .distributions .dist_math import clipped_beta_rvs
35
36
from pymc3 .distributions .multivariate import quaddist_matrix
36
37
from pymc3 .distributions .shape_utils import to_tuple
37
38
from pymc3 .exceptions import ShapeError
@@ -278,11 +279,6 @@ class TestWald(BaseTestCases.BaseTestCase):
278
279
params = {"mu" : 1.0 , "lam" : 1.0 , "alpha" : 0.0 }
279
280
280
281
281
- class TestBeta (BaseTestCases .BaseTestCase ):
282
- distribution = pm .Beta
283
- params = {"alpha" : 1.0 , "beta" : 1.0 }
284
-
285
-
286
282
@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
287
283
class TestKumaraswamy (BaseTestCases .BaseTestCase ):
288
284
distribution = pm .Kumaraswamy
@@ -355,11 +351,6 @@ class TestBetaBinomial(BaseTestCases.BaseTestCase):
355
351
params = {"n" : 5 , "alpha" : 1.0 , "beta" : 1.0 }
356
352
357
353
358
- class TestDiscreteWeibull (BaseTestCases .BaseTestCase ):
359
- distribution = pm .DiscreteWeibull
360
- params = {"q" : 0.25 , "beta" : 2.0 }
361
-
362
-
363
354
@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
364
355
class TestConstant (BaseTestCases .BaseTestCase ):
365
356
distribution = pm .Constant
@@ -426,17 +417,10 @@ def test_distribution(self):
426
417
self ._instantiate_pymc_rv ()
427
418
if self .reference_dist is not None :
428
419
self .reference_dist_draws = self .reference_dist ()(
429
- ** self .reference_dist_params , size = self .size
420
+ size = self .size , ** self .reference_dist_params
430
421
)
431
- for test_name in self .tests_to_run :
432
- self .run_test (test_name )
433
-
434
- def run_test (self , test_name ):
435
- {
436
- "check_pymc_dist_matches_reference" : self ._check_pymc_draws_match_reference ,
437
- "check_pymc_params_match_rv_op" : self ._check_pymc_params_match_rv_op ,
438
- "check_rv_size" : self ._check_rv_size ,
439
- }[test_name ]()
422
+ for check_name in self .tests_to_run :
423
+ getattr (self , check_name )()
440
424
441
425
def _instantiate_pymc_rv (self , dist_params = None ):
442
426
params = dist_params if dist_params else self .pymc_dist_params
@@ -448,25 +432,22 @@ def _instantiate_pymc_rv(self, dist_params=None):
448
432
name = f"{ self .pymc_dist .rv_op .name } _test" ,
449
433
)
450
434
451
- def _check_pymc_draws_match_reference (self ):
435
+ def check_pymc_draws_match_reference (self ):
452
436
# need to re-instantiate it to make sure that the order of drawings match the reference distribution one
453
437
self ._instantiate_pymc_rv ()
454
438
assert_array_almost_equal (
455
439
self .pymc_rv .eval (), self .reference_dist_draws , decimal = self .decimal
456
440
)
457
441
458
- def _check_pymc_params_match_rv_op (self ) -> None :
459
- try :
460
- aesera_dist_inputs = self .pymc_rv .get_parents ()[0 ].inputs [3 :]
461
- except :
462
- raise Exception ("Parent Apply node missing from output" )
442
+ def check_pymc_params_match_rv_op (self ) -> None :
443
+ aesera_dist_inputs = self .pymc_rv .get_parents ()[0 ].inputs [3 :]
463
444
assert len (self .expected_rv_op_params ) == len (aesera_dist_inputs )
464
445
for (expected_name , expected_value ), actual_variable in zip (
465
446
self .expected_rv_op_params .items (), aesera_dist_inputs
466
447
):
467
448
assert_almost_equal (expected_value , actual_variable .eval (), decimal = self .decimal )
468
449
469
- def _check_rv_size (self ):
450
+ def check_rv_size (self ):
470
451
# test sizes
471
452
sizes_to_check = self .sizes_to_check or [None , (), 1 , (1 ,), 5 , (4 , 5 ), (2 , 4 , 2 )]
472
453
sizes_expected = self .sizes_expected or [(), (), (1 ,), (1 ,), (5 ,), (4 , 5 ), (2 , 4 , 2 )]
@@ -508,6 +489,28 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
508
489
)
509
490
510
491
492
+ class TestDiscreteWeibull (BaseTestDistribution ):
493
+ def discrete_weibul_rng_fn (self ):
494
+ p = seeded_numpy_distribution_builder ("uniform" )
495
+ return (
496
+ lambda size , q , beta : np .ceil (
497
+ np .power (np .log (1 - p (self )(size = size )) / np .log (q ), 1.0 / beta )
498
+ )
499
+ - 1
500
+ )
501
+
502
+ pymc_dist = pm .DiscreteWeibull
503
+ pymc_dist_params = {"q" : 0.25 , "beta" : 2.0 }
504
+ expected_rv_op_params = {"q" : 0.25 , "beta" : 2.0 }
505
+ reference_dist_params = {"q" : 0.25 , "beta" : 2.0 }
506
+ reference_dist = discrete_weibul_rng_fn
507
+ tests_to_run = [
508
+ "check_pymc_params_match_rv_op" ,
509
+ "check_rv_size" ,
510
+ "check_pymc_dist_matches_reference" ,
511
+ ]
512
+
513
+
511
514
class TestGumbel (BaseTestDistribution ):
512
515
pymc_dist = pm .Gumbel
513
516
pymc_dist_params = {"mu" : 1.5 , "beta" : 3.0 }
@@ -584,7 +587,16 @@ class TestBeta(BaseTestDistribution):
584
587
pymc_dist = pm .Beta
585
588
pymc_dist_params = {"alpha" : 2.0 , "beta" : 5.0 }
586
589
expected_rv_op_params = {"alpha" : 2.0 , "beta" : 5.0 }
587
- tests_to_run = ["check_pymc_params_match_rv_op" ]
590
+ reference_dist_params = {"a" : 2.0 , "b" : 5.0 }
591
+ size = 15
592
+ reference_dist = lambda self : functools .partial (
593
+ clipped_beta_rvs , random_state = self .get_random_state ()
594
+ )
595
+ tests_to_run = [
596
+ "check_pymc_params_match_rv_op" ,
597
+ "check_rv_size" ,
598
+ "check_pymc_params_match_rv_op" ,
599
+ ]
588
600
589
601
590
602
class TestBetaMuSigma (BaseTestDistribution ):
0 commit comments