diff --git a/pymc/tests/test_logprob.py b/pymc/tests/test_logprob.py index 74952c86a3..db924b88d2 100644 --- a/pymc/tests/test_logprob.py +++ b/pymc/tests/test_logprob.py @@ -33,7 +33,13 @@ from pymc.aesaraf import floatX, walk_model from pymc.distributions.continuous import HalfFlat, Normal, TruncatedNormal, Uniform from pymc.distributions.discrete import Bernoulli -from pymc.distributions.logprob import ignore_logprob, joint_logpt, logcdf, logp +from pymc.distributions.logprob import ( + _get_scaling, + ignore_logprob, + joint_logpt, + logcdf, + logp, +) from pymc.model import Model, Potential from pymc.tests.helpers import select_by_precision @@ -43,6 +49,53 @@ def assert_no_rvs(var): return var +def test_get_scaling(): + + assert _get_scaling(None, (2, 3), 2).eval() == 1 + # ndim >=1 & ndim<1 + assert _get_scaling(45, (2, 3), 1).eval() == 22.5 + assert _get_scaling(45, (2, 3), 0).eval() == 45 + + # list or tuple tests + # total_size contains other than Ellipsis, None and Int + with pytest.raises(TypeError, match="Unrecognized `total_size` type"): + _get_scaling([2, 4, 5, 9, 11.5], (2, 3), 2) + # check with Ellipsis + with pytest.raises(ValueError, match="Double Ellipsis in `total_size` is restricted"): + _get_scaling([1, 2, 5, Ellipsis, Ellipsis], (2, 3), 2) + with pytest.raises( + ValueError, + match="Length of `total_size` is too big, number of scalings is bigger that ndim", + ): + _get_scaling([1, 2, 5, Ellipsis], (2, 3), 2) + + assert _get_scaling([Ellipsis], (2, 3), 2).eval() == 1 + + assert _get_scaling([4, 5, 9, Ellipsis, 32, 12], (2, 3, 2), 5).eval() == 960 + assert _get_scaling([4, 5, 9, Ellipsis], (2, 3, 2), 5).eval() == 15 + # total_size with no Ellipsis (end = [ ]) + with pytest.raises( + ValueError, + match="Length of `total_size` is too big, number of scalings is bigger that ndim", + ): + _get_scaling([1, 2, 5], (2, 3), 2) + + assert _get_scaling([], (2, 3), 2).eval() == 1 + assert _get_scaling((), (2, 3), 2).eval() == 1 + # total_size invalid type + with pytest.raises( + TypeError, + match="Unrecognized `total_size` type, expected int or list of ints, got {1, 2, 5}", + ): + _get_scaling({1, 2, 5}, (2, 3), 2) + + # test with rvar from model graph + with Model() as m2: + rv_var = Uniform("a", 0.0, 1.0) + total_size = [] + assert _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim).eval() == 1.0 + + def test_joint_logpt_basic(): """Make sure we can compute a log-likelihood for a hierarchical model with transforms."""