Skip to content

Commit 9b52e1b

Browse files
Fix typos in checks naming and add sanity check
1 parent 3d28087 commit 9b52e1b

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ class BaseTestDistribution(SeededTest):
414414
repeated_params_shape = 5
415415

416416
def test_distribution(self):
417+
self.validate_tests_list()
417418
self._instantiate_pymc_rv()
418419
if self.reference_dist is not None:
419420
self.reference_dist_draws = self.reference_dist()(
@@ -439,7 +440,7 @@ def check_pymc_draws_match_reference(self):
439440
self.pymc_rv.eval(), self.reference_dist_draws, decimal=self.decimal
440441
)
441442

442-
def check_pymc_params_match_rv_op(self) -> None:
443+
def check_pymc_params_match_rv_op(self):
443444
aesera_dist_inputs = self.pymc_rv.get_parents()[0].inputs[3:]
444445
assert len(self.expected_rv_op_params) == len(aesera_dist_inputs)
445446
for (expected_name, expected_value), actual_variable in zip(
@@ -476,6 +477,11 @@ def check_rv_size(self):
476477
actual = change_rv_size(self.pymc_rv, size).eval().shape
477478
assert actual == expected
478479

480+
def validate_tests_list(self):
481+
assert len(self.tests_to_run) == len(
482+
set(self.tests_to_run)
483+
), "There are duplicates in the list of tests_to_run"
484+
479485

480486
def seeded_scipy_distribution_builder(dist_name: str) -> Callable:
481487
return lambda self: functools.partial(
@@ -490,24 +496,24 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
490496

491497

492498
class TestDiscreteWeibull(BaseTestDistribution):
493-
def discrete_weibul_rng_fn(self):
494-
p = seeded_numpy_distribution_builder("uniform")
495-
return (
496-
lambda size, q, beta: np.ceil(
497-
np.power(np.log(1 - p(self)(size=size)) / np.log(q), 1.0 / beta)
498-
)
499-
- 1
499+
def discrete_weibul_rng_fn(self, size, q, beta, uniform_rng_fct):
500+
return np.ceil(np.power(np.log(1 - uniform_rng_fct(size=size)) / np.log(q), 1.0 / beta)) - 1
501+
502+
def seeded_discrete_weibul_rng_fn(self):
503+
uniform_rng_fct = functools.partial(
504+
getattr(np.random.RandomState, "uniform"), self.get_random_state()
500505
)
506+
return functools.partial(self.discrete_weibul_rng_fn, uniform_rng_fct=uniform_rng_fct)
501507

502508
pymc_dist = pm.DiscreteWeibull
503509
pymc_dist_params = {"q": 0.25, "beta": 2.0}
504510
expected_rv_op_params = {"q": 0.25, "beta": 2.0}
505511
reference_dist_params = {"q": 0.25, "beta": 2.0}
506-
reference_dist = discrete_weibul_rng_fn
512+
reference_dist = seeded_discrete_weibul_rng_fn
507513
tests_to_run = [
508514
"check_pymc_params_match_rv_op",
509515
"check_rv_size",
510-
"check_pymc_dist_matches_reference",
516+
"check_pymc_draws_match_reference",
511517
]
512518

513519

@@ -521,7 +527,7 @@ class TestGumbel(BaseTestDistribution):
521527
tests_to_run = [
522528
"check_pymc_params_match_rv_op",
523529
"check_rv_size",
524-
"check_pymc_dist_matches_reference",
530+
"check_pymc_draws_match_reference",
525531
]
526532

527533

@@ -535,7 +541,7 @@ class TestNormal(BaseTestDistribution):
535541
tests_to_run = [
536542
"check_pymc_params_match_rv_op",
537543
"check_rv_size",
538-
"check_pymc_dist_matches_reference",
544+
"check_pymc_draws_match_reference",
539545
]
540546

541547

@@ -595,7 +601,7 @@ class TestBeta(BaseTestDistribution):
595601
tests_to_run = [
596602
"check_pymc_params_match_rv_op",
597603
"check_rv_size",
598-
"check_pymc_params_match_rv_op",
604+
"check_pymc_draws_match_reference",
599605
]
600606

601607

0 commit comments

Comments
 (0)