@@ -414,6 +414,7 @@ class BaseTestDistribution(SeededTest):
414
414
repeated_params_shape = 5
415
415
416
416
def test_distribution (self ):
417
+ self .validate_tests_list ()
417
418
self ._instantiate_pymc_rv ()
418
419
if self .reference_dist is not None :
419
420
self .reference_dist_draws = self .reference_dist ()(
@@ -439,7 +440,7 @@ def check_pymc_draws_match_reference(self):
439
440
self .pymc_rv .eval (), self .reference_dist_draws , decimal = self .decimal
440
441
)
441
442
442
- def check_pymc_params_match_rv_op (self ) -> None :
443
+ def check_pymc_params_match_rv_op (self ):
443
444
aesera_dist_inputs = self .pymc_rv .get_parents ()[0 ].inputs [3 :]
444
445
assert len (self .expected_rv_op_params ) == len (aesera_dist_inputs )
445
446
for (expected_name , expected_value ), actual_variable in zip (
@@ -476,6 +477,11 @@ def check_rv_size(self):
476
477
actual = change_rv_size (self .pymc_rv , size ).eval ().shape
477
478
assert actual == expected
478
479
480
+ def validate_tests_list (self ):
481
+ assert len (self .tests_to_run ) == len (
482
+ set (self .tests_to_run )
483
+ ), "There are duplicates in the list of tests_to_run"
484
+
479
485
480
486
def seeded_scipy_distribution_builder (dist_name : str ) -> Callable :
481
487
return lambda self : functools .partial (
@@ -490,24 +496,24 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
490
496
491
497
492
498
class TestDiscreteWeibull (BaseTestDistribution ):
493
- def discrete_weibul_rng_fn (self ):
494
- p = seeded_numpy_distribution_builder ("uniform" )
495
- return (
496
- lambda size , q , beta : np .ceil (
497
- np .power (np .log (1 - p (self )(size = size )) / np .log (q ), 1.0 / beta )
498
- )
499
- - 1
499
+ def discrete_weibul_rng_fn (self , size , q , beta , uniform_rng_fct ):
500
+ return np .ceil (np .power (np .log (1 - uniform_rng_fct (size = size )) / np .log (q ), 1.0 / beta )) - 1
501
+
502
+ def seeded_discrete_weibul_rng_fn (self ):
503
+ uniform_rng_fct = functools .partial (
504
+ getattr (np .random .RandomState , "uniform" ), self .get_random_state ()
500
505
)
506
+ return functools .partial (self .discrete_weibul_rng_fn , uniform_rng_fct = uniform_rng_fct )
501
507
502
508
pymc_dist = pm .DiscreteWeibull
503
509
pymc_dist_params = {"q" : 0.25 , "beta" : 2.0 }
504
510
expected_rv_op_params = {"q" : 0.25 , "beta" : 2.0 }
505
511
reference_dist_params = {"q" : 0.25 , "beta" : 2.0 }
506
- reference_dist = discrete_weibul_rng_fn
512
+ reference_dist = seeded_discrete_weibul_rng_fn
507
513
tests_to_run = [
508
514
"check_pymc_params_match_rv_op" ,
509
515
"check_rv_size" ,
510
- "check_pymc_dist_matches_reference " ,
516
+ "check_pymc_draws_match_reference " ,
511
517
]
512
518
513
519
@@ -521,7 +527,7 @@ class TestGumbel(BaseTestDistribution):
521
527
tests_to_run = [
522
528
"check_pymc_params_match_rv_op" ,
523
529
"check_rv_size" ,
524
- "check_pymc_dist_matches_reference " ,
530
+ "check_pymc_draws_match_reference " ,
525
531
]
526
532
527
533
@@ -535,7 +541,7 @@ class TestNormal(BaseTestDistribution):
535
541
tests_to_run = [
536
542
"check_pymc_params_match_rv_op" ,
537
543
"check_rv_size" ,
538
- "check_pymc_dist_matches_reference " ,
544
+ "check_pymc_draws_match_reference " ,
539
545
]
540
546
541
547
@@ -595,7 +601,7 @@ class TestBeta(BaseTestDistribution):
595
601
tests_to_run = [
596
602
"check_pymc_params_match_rv_op" ,
597
603
"check_rv_size" ,
598
- "check_pymc_params_match_rv_op " ,
604
+ "check_pymc_draws_match_reference " ,
599
605
]
600
606
601
607
0 commit comments