Skip to content

Commit 3d28087

Browse files
Add discrete Weibull, improve Beta and some minor refactoring
1 parent 706308e commit 3d28087

File tree

1 file changed

+40
-28
lines changed

1 file changed

+40
-28
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
from pymc3.aesaraf import change_rv_size, floatX, intX
3434
from pymc3.distributions.continuous import get_tau_sigma
35+
from pymc3.distributions.dist_math import clipped_beta_rvs
3536
from pymc3.distributions.multivariate import quaddist_matrix
3637
from pymc3.distributions.shape_utils import to_tuple
3738
from pymc3.exceptions import ShapeError
@@ -278,11 +279,6 @@ class TestWald(BaseTestCases.BaseTestCase):
278279
params = {"mu": 1.0, "lam": 1.0, "alpha": 0.0}
279280

280281

281-
class TestBeta(BaseTestCases.BaseTestCase):
282-
distribution = pm.Beta
283-
params = {"alpha": 1.0, "beta": 1.0}
284-
285-
286282
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
287283
class TestKumaraswamy(BaseTestCases.BaseTestCase):
288284
distribution = pm.Kumaraswamy
@@ -355,11 +351,6 @@ class TestBetaBinomial(BaseTestCases.BaseTestCase):
355351
params = {"n": 5, "alpha": 1.0, "beta": 1.0}
356352

357353

358-
class TestDiscreteWeibull(BaseTestCases.BaseTestCase):
359-
distribution = pm.DiscreteWeibull
360-
params = {"q": 0.25, "beta": 2.0}
361-
362-
363354
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
364355
class TestConstant(BaseTestCases.BaseTestCase):
365356
distribution = pm.Constant
@@ -426,17 +417,10 @@ def test_distribution(self):
426417
self._instantiate_pymc_rv()
427418
if self.reference_dist is not None:
428419
self.reference_dist_draws = self.reference_dist()(
429-
**self.reference_dist_params, size=self.size
420+
size=self.size, **self.reference_dist_params
430421
)
431-
for test_name in self.tests_to_run:
432-
self.run_test(test_name)
433-
434-
def run_test(self, test_name):
435-
{
436-
"check_pymc_dist_matches_reference": self._check_pymc_draws_match_reference,
437-
"check_pymc_params_match_rv_op": self._check_pymc_params_match_rv_op,
438-
"check_rv_size": self._check_rv_size,
439-
}[test_name]()
422+
for check_name in self.tests_to_run:
423+
getattr(self, check_name)()
440424

441425
def _instantiate_pymc_rv(self, dist_params=None):
442426
params = dist_params if dist_params else self.pymc_dist_params
@@ -448,25 +432,22 @@ def _instantiate_pymc_rv(self, dist_params=None):
448432
name=f"{self.pymc_dist.rv_op.name}_test",
449433
)
450434

451-
def _check_pymc_draws_match_reference(self):
435+
def check_pymc_draws_match_reference(self):
452436
# need to re-instantiate it to make sure that the order of drawings match the reference distribution one
453437
self._instantiate_pymc_rv()
454438
assert_array_almost_equal(
455439
self.pymc_rv.eval(), self.reference_dist_draws, decimal=self.decimal
456440
)
457441

458-
def _check_pymc_params_match_rv_op(self) -> None:
459-
try:
460-
aesera_dist_inputs = self.pymc_rv.get_parents()[0].inputs[3:]
461-
except:
462-
raise Exception("Parent Apply node missing from output")
442+
def check_pymc_params_match_rv_op(self) -> None:
443+
aesera_dist_inputs = self.pymc_rv.get_parents()[0].inputs[3:]
463444
assert len(self.expected_rv_op_params) == len(aesera_dist_inputs)
464445
for (expected_name, expected_value), actual_variable in zip(
465446
self.expected_rv_op_params.items(), aesera_dist_inputs
466447
):
467448
assert_almost_equal(expected_value, actual_variable.eval(), decimal=self.decimal)
468449

469-
def _check_rv_size(self):
450+
def check_rv_size(self):
470451
# test sizes
471452
sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
472453
sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
@@ -508,6 +489,28 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
508489
)
509490

510491

492+
class TestDiscreteWeibull(BaseTestDistribution):
493+
def discrete_weibul_rng_fn(self):
494+
p = seeded_numpy_distribution_builder("uniform")
495+
return (
496+
lambda size, q, beta: np.ceil(
497+
np.power(np.log(1 - p(self)(size=size)) / np.log(q), 1.0 / beta)
498+
)
499+
- 1
500+
)
501+
502+
pymc_dist = pm.DiscreteWeibull
503+
pymc_dist_params = {"q": 0.25, "beta": 2.0}
504+
expected_rv_op_params = {"q": 0.25, "beta": 2.0}
505+
reference_dist_params = {"q": 0.25, "beta": 2.0}
506+
reference_dist = discrete_weibul_rng_fn
507+
tests_to_run = [
508+
"check_pymc_params_match_rv_op",
509+
"check_rv_size",
510+
"check_pymc_dist_matches_reference",
511+
]
512+
513+
511514
class TestGumbel(BaseTestDistribution):
512515
pymc_dist = pm.Gumbel
513516
pymc_dist_params = {"mu": 1.5, "beta": 3.0}
@@ -584,7 +587,16 @@ class TestBeta(BaseTestDistribution):
584587
pymc_dist = pm.Beta
585588
pymc_dist_params = {"alpha": 2.0, "beta": 5.0}
586589
expected_rv_op_params = {"alpha": 2.0, "beta": 5.0}
587-
tests_to_run = ["check_pymc_params_match_rv_op"]
590+
reference_dist_params = {"a": 2.0, "b": 5.0}
591+
size = 15
592+
reference_dist = lambda self: functools.partial(
593+
clipped_beta_rvs, random_state=self.get_random_state()
594+
)
595+
tests_to_run = [
596+
"check_pymc_params_match_rv_op",
597+
"check_rv_size",
598+
"check_pymc_params_match_rv_op",
599+
]
588600

589601

590602
class TestBetaMuSigma(BaseTestDistribution):

0 commit comments

Comments
 (0)