Skip to content

Commit 946cf34

Browse files
matteo-palliniricardoV94
authored andcommitted
V4 update test framework for distributions random method (#4608)
* Update tests following distributions refactoring The distributions refactoring moves the random variable sampling to aesara. This relies on numpy and scipy random variables implementation. So, now the only thing we care about testing is that the parametrization on the PyMC side is sendible given the one on the Aesara side (effectively the numpy/scipy one) More details can be found on issue #4554 #4554 * Change tests for more refactored distributions. More details can be found on issue #4554 #4554 * Change tests for refactored distributions More details can be found on issue #4554 #4554 * Remove tests for random variable samples shape and size Most of the random variable logic has been moved to aesara, as well as most of the relative tests. More details can be found on issue #4554 * Fix test for half cauchy, renmae mv normal tests and add test for Bernoulli * Add test checking PyMC samples match the aesara ones Also mark test_categorical as expected to fail due to bug on aesara side. The bug is going to be fixed with 2.0.5 release, so we need to bump the version for categorical and the test to pass. * Move Aesara to 2.0.5 to include Gumbel distribution * Enamble exponential and gamma tests following bug-fix * Enable categorical test following aesara version bump to 2.0.5 and relative bug-fix * Few small cosmetic changes: - replace list of tuples with dict - rename 1 method - move pymc_dist as first argument in function call - replace list(params) with params.copy() * Remove redundant tests * Further refactoring The refactoring should make it possible testing both the distribution parametrization and sampled values according to need, as well as any other future test. More details on PR #4608 * Add size tests to new rv testing framework * Add tests for multivariate and for univariate multi-parameters * remove test already covered in aesara * fix few names * Remove "distribution" from test class names * Add discrete Weibull, improve Beta and some minor refactoring * Fix typos in checks naming and add sanity check Co-authored-by: Ricardo <[email protected]>
1 parent 51b22b4 commit 946cf34

File tree

4 files changed

+396
-327
lines changed

4 files changed

+396
-327
lines changed

pymc3/distributions/discrete.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -713,24 +713,23 @@ def NegBinom(a, m, x):
713713

714714
@classmethod
715715
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
716-
n, p = cls.get_mu_alpha(mu, alpha, p, n)
716+
n, p = cls.get_n_p(mu, alpha, p, n)
717717
n = at.as_tensor_variable(floatX(n))
718718
p = at.as_tensor_variable(floatX(p))
719719
return super().dist([n, p], *args, **kwargs)
720720

721721
@classmethod
722-
def get_mu_alpha(cls, mu=None, alpha=None, p=None, n=None):
722+
def get_n_p(cls, mu=None, alpha=None, p=None, n=None):
723723
if n is None:
724724
if alpha is not None:
725-
n = at.as_tensor_variable(floatX(alpha))
725+
n = alpha
726726
else:
727727
raise ValueError("Incompatible parametrization. Must specify either alpha or n.")
728728
elif alpha is not None:
729729
raise ValueError("Incompatible parametrization. Can't specify both alpha and n.")
730730

731731
if p is None:
732732
if mu is not None:
733-
mu = at.as_tensor_variable(floatX(mu))
734733
p = n / (mu + n)
735734
else:
736735
raise ValueError("Incompatible parametrization. Must specify either mu or p.")

pymc3/tests/helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
class SeededTest:
2929
random_seed = 20160911
30+
random_state = None
3031

3132
@classmethod
3233
def setup_class(cls):
@@ -40,6 +41,11 @@ def setup_method(self):
4041
def teardown_method(self):
4142
set_at_rng(self.old_at_rng)
4243

44+
def get_random_state(self, reset=False):
45+
if self.random_state is None or reset:
46+
self.random_state = nr.RandomState(self.random_seed)
47+
return self.random_state
48+
4349

4450
class LoggingHandler(BufferingHandler):
4551
def __init__(self, matcher):

0 commit comments

Comments
 (0)