Skip to content

Commit 49eecfb

Browse files
chritterIcyshamanchritter
authored
Introduced tests for pymc.distributions.logprob._get_scaling_ (#5544)
* Introduced test for _get_scaling_ Co-authored-by: Icyshaman <[email protected]> Co-authored-by: chritter <[email protected]>
1 parent 53ad689 commit 49eecfb

File tree

1 file changed

+54
-1
lines changed

1 file changed

+54
-1
lines changed

pymc/tests/test_logprob.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,13 @@
3333
from pymc.aesaraf import floatX, walk_model
3434
from pymc.distributions.continuous import HalfFlat, Normal, TruncatedNormal, Uniform
3535
from pymc.distributions.discrete import Bernoulli
36-
from pymc.distributions.logprob import ignore_logprob, joint_logpt, logcdf, logp
36+
from pymc.distributions.logprob import (
37+
_get_scaling,
38+
ignore_logprob,
39+
joint_logpt,
40+
logcdf,
41+
logp,
42+
)
3743
from pymc.model import Model, Potential
3844
from pymc.tests.helpers import select_by_precision
3945

@@ -43,6 +49,53 @@ def assert_no_rvs(var):
4349
return var
4450

4551

52+
def test_get_scaling():
53+
54+
assert _get_scaling(None, (2, 3), 2).eval() == 1
55+
# ndim >=1 & ndim<1
56+
assert _get_scaling(45, (2, 3), 1).eval() == 22.5
57+
assert _get_scaling(45, (2, 3), 0).eval() == 45
58+
59+
# list or tuple tests
60+
# total_size contains other than Ellipsis, None and Int
61+
with pytest.raises(TypeError, match="Unrecognized `total_size` type"):
62+
_get_scaling([2, 4, 5, 9, 11.5], (2, 3), 2)
63+
# check with Ellipsis
64+
with pytest.raises(ValueError, match="Double Ellipsis in `total_size` is restricted"):
65+
_get_scaling([1, 2, 5, Ellipsis, Ellipsis], (2, 3), 2)
66+
with pytest.raises(
67+
ValueError,
68+
match="Length of `total_size` is too big, number of scalings is bigger that ndim",
69+
):
70+
_get_scaling([1, 2, 5, Ellipsis], (2, 3), 2)
71+
72+
assert _get_scaling([Ellipsis], (2, 3), 2).eval() == 1
73+
74+
assert _get_scaling([4, 5, 9, Ellipsis, 32, 12], (2, 3, 2), 5).eval() == 960
75+
assert _get_scaling([4, 5, 9, Ellipsis], (2, 3, 2), 5).eval() == 15
76+
# total_size with no Ellipsis (end = [ ])
77+
with pytest.raises(
78+
ValueError,
79+
match="Length of `total_size` is too big, number of scalings is bigger that ndim",
80+
):
81+
_get_scaling([1, 2, 5], (2, 3), 2)
82+
83+
assert _get_scaling([], (2, 3), 2).eval() == 1
84+
assert _get_scaling((), (2, 3), 2).eval() == 1
85+
# total_size invalid type
86+
with pytest.raises(
87+
TypeError,
88+
match="Unrecognized `total_size` type, expected int or list of ints, got {1, 2, 5}",
89+
):
90+
_get_scaling({1, 2, 5}, (2, 3), 2)
91+
92+
# test with rvar from model graph
93+
with Model() as m2:
94+
rv_var = Uniform("a", 0.0, 1.0)
95+
total_size = []
96+
assert _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim).eval() == 1.0
97+
98+
4699
def test_joint_logpt_basic():
47100
"""Make sure we can compute a log-likelihood for a hierarchical model with transforms."""
48101

0 commit comments

Comments
 (0)