Skip to content

Commit 53dfb75

Browse files
Remove tests for random variable samples shape and size
Most of the random variable logic has been moved to aesara, as well as most of the relative tests. More details can be found on issue pymc-devs#4554
1 parent 069533f commit 53dfb75

File tree

1 file changed

+0
-111
lines changed

1 file changed

+0
-111
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 0 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -245,12 +245,6 @@ class TestGaussianRandomWalk(BaseTestCases.BaseTestCase):
245245
default_shape = (1,)
246246

247247

248-
@pytest.mark.skip(reason="This test is covered by Aesara")
249-
class TestNormal(BaseTestCases.BaseTestCase):
250-
distribution = pm.Normal
251-
params = {"mu": 0.0, "tau": 1.0}
252-
253-
254248
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
255249
class TestTruncatedNormal(BaseTestCases.BaseTestCase):
256250
distribution = pm.TruncatedNormal
@@ -275,18 +269,6 @@ class TestSkewNormal(BaseTestCases.BaseTestCase):
275269
params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}
276270

277271

278-
@pytest.mark.skip(reason="This test is covered by Aesara")
279-
class TestHalfNormal(BaseTestCases.BaseTestCase):
280-
distribution = pm.HalfNormal
281-
params = {"tau": 1.0}
282-
283-
284-
@pytest.mark.skip(reason="This test is covered by Aesara")
285-
class TestUniform(BaseTestCases.BaseTestCase):
286-
distribution = pm.Uniform
287-
params = {"lower": 0.0, "upper": 1.0}
288-
289-
290272
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
291273
class TestTriangular(BaseTestCases.BaseTestCase):
292274
distribution = pm.Triangular
@@ -310,12 +292,6 @@ class TestKumaraswamy(BaseTestCases.BaseTestCase):
310292
params = {"a": 1.0, "b": 1.0}
311293

312294

313-
@pytest.mark.skip(reason="This test is covered by Aesara")
314-
class TestExponential(BaseTestCases.BaseTestCase):
315-
distribution = pm.Exponential
316-
params = {"lam": 1.0}
317-
318-
319295
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
320296
class TestLaplace(BaseTestCases.BaseTestCase):
321297
distribution = pm.Laplace
@@ -346,30 +322,6 @@ class TestPareto(BaseTestCases.BaseTestCase):
346322
params = {"alpha": 0.5, "m": 1.0}
347323

348324

349-
@pytest.mark.skip(reason="This test is covered by Aesara")
350-
class TestCauchy(BaseTestCases.BaseTestCase):
351-
distribution = pm.Cauchy
352-
params = {"alpha": 1.0, "beta": 1.0}
353-
354-
355-
@pytest.mark.skip(reason="This test is covered by Aesara")
356-
class TestHalfCauchy(BaseTestCases.BaseTestCase):
357-
distribution = pm.HalfCauchy
358-
params = {"beta": 1.0}
359-
360-
361-
@pytest.mark.skip(reason="This test is covered by Aesara")
362-
class TestGamma(BaseTestCases.BaseTestCase):
363-
distribution = pm.Gamma
364-
params = {"alpha": 1.0, "beta": 1.0}
365-
366-
367-
@pytest.mark.skip(reason="This test is covered by Aesara")
368-
class TestInverseGamma(BaseTestCases.BaseTestCase):
369-
distribution = pm.InverseGamma
370-
params = {"alpha": 0.5, "beta": 0.5}
371-
372-
373325
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
374326
class TestChiSquared(BaseTestCases.BaseTestCase):
375327
distribution = pm.ChiSquared
@@ -412,42 +364,18 @@ class TestLogitNormal(BaseTestCases.BaseTestCase):
412364
params = {"mu": 0.0, "sigma": 1.0}
413365

414366

415-
@pytest.mark.skip(reason="This test is covered by Aesara")
416-
class TestBinomial(BaseTestCases.BaseTestCase):
417-
distribution = pm.Binomial
418-
params = {"n": 5, "p": 0.5}
419-
420-
421367
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
422368
class TestBetaBinomial(BaseTestCases.BaseTestCase):
423369
distribution = pm.BetaBinomial
424370
params = {"n": 5, "alpha": 1.0, "beta": 1.0}
425371

426372

427-
@pytest.mark.skip(reason="This test is covered by Aesara")
428-
class TestBernoulli(BaseTestCases.BaseTestCase):
429-
distribution = pm.Bernoulli
430-
params = {"p": 0.5}
431-
432-
433373
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
434374
class TestDiscreteWeibull(BaseTestCases.BaseTestCase):
435375
distribution = pm.DiscreteWeibull
436376
params = {"q": 0.25, "beta": 2.0}
437377

438378

439-
@pytest.mark.skip(reason="This test is covered by Aesara")
440-
class TestPoisson(BaseTestCases.BaseTestCase):
441-
distribution = pm.Poisson
442-
params = {"mu": 1.0}
443-
444-
445-
@pytest.mark.skip(reason="This test is covered by Aesara")
446-
class TestNegativeBinomial(BaseTestCases.BaseTestCase):
447-
distribution = pm.NegativeBinomial
448-
params = {"mu": 1.0, "alpha": 1.0}
449-
450-
451379
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
452380
class TestConstant(BaseTestCases.BaseTestCase):
453381
distribution = pm.Constant
@@ -496,45 +424,6 @@ class TestMoyal(BaseTestCases.BaseTestCase):
496424
params = {"mu": 0.0, "sigma": 1.0}
497425

498426

499-
@pytest.mark.skip(reason="This test is covered by Aesara")
500-
class TestCategorical(BaseTestCases.BaseTestCase):
501-
distribution = pm.Categorical
502-
params = {"p": np.ones(BaseTestCases.BaseTestCase.shape)}
503-
504-
def get_random_variable(
505-
self, shape, with_vector_params=False, **kwargs
506-
): # don't transform categories
507-
return super().get_random_variable(shape, with_vector_params=False, **kwargs)
508-
509-
def test_probability_vector_shape(self):
510-
"""Check that if a 2d array of probabilities are passed to categorical correct shape is returned"""
511-
p = np.ones((10, 5))
512-
assert pm.Categorical.dist(p=p).random().shape == (10,)
513-
assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 10)
514-
p = np.ones((3, 7, 5))
515-
assert pm.Categorical.dist(p=p).random().shape == (3, 7)
516-
assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 3, 7)
517-
518-
519-
@pytest.mark.skip(reason="This test is covered by Aesara")
520-
class TestDirichlet(SeededTest):
521-
@pytest.mark.parametrize(
522-
"shape, size",
523-
[
524-
((2), (1)),
525-
((2), (2)),
526-
((2, 2), (2, 100)),
527-
((3, 4), (3, 4)),
528-
((3, 4), (3, 4, 100)),
529-
((3, 4), (100)),
530-
((3, 4), (1)),
531-
],
532-
)
533-
def test_dirichlet_random_shape(self, shape, size):
534-
out_shape = to_tuple(size) + to_tuple(shape)
535-
assert pm.Dirichlet.dist(a=np.ones(shape)).random(size=size).shape == out_shape
536-
537-
538427
class TestCorrectParametrizationMappingPymcToScipy(SeededTest):
539428
@staticmethod
540429
def get_inputs_from_apply_node_outputs(outputs):

0 commit comments

Comments
 (0)