Skip to content

Commit af87a5d

Browse files
chritterIcyshaman
authored andcommitted
Introduced test for _get_scaling_
Co-authored-by: Icyshaman <[email protected]>
1 parent 53ad689 commit af87a5d

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

pymc/tests/test_logprob.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,13 @@
3333
from pymc.aesaraf import floatX, walk_model
3434
from pymc.distributions.continuous import HalfFlat, Normal, TruncatedNormal, Uniform
3535
from pymc.distributions.discrete import Bernoulli
36-
from pymc.distributions.logprob import ignore_logprob, joint_logpt, logcdf, logp
36+
from pymc.distributions.logprob import (
37+
_get_scaling,
38+
ignore_logprob,
39+
joint_logpt,
40+
logcdf,
41+
logp,
42+
)
3743
from pymc.model import Model, Potential
3844
from pymc.tests.helpers import select_by_precision
3945

@@ -43,6 +49,44 @@ def assert_no_rvs(var):
4349
return var
4450

4551

52+
def test_get_scaling():
53+
54+
assert _get_scaling(None, (2, 3), 2).eval() == 1
55+
# ndim >=1 & ndim<1
56+
assert _get_scaling(45, (2, 3), 1).eval() == 22.5
57+
assert _get_scaling(45, (2, 3), 0).eval() == 45
58+
59+
# list or tuple tests
60+
# total_size contains other than Ellipsis, None and Int
61+
with pytest.raises(TypeError):
62+
_get_scaling([2, 4, 5, 9, 11.5], (2, 3), 2)
63+
# check with Ellipsis
64+
with pytest.raises(ValueError):
65+
_get_scaling([1, 2, 5, Ellipsis, Ellipsis], (2, 3), 2)
66+
with pytest.raises(ValueError):
67+
_get_scaling([1, 2, 5, Ellipsis], (2, 3), 2)
68+
69+
assert _get_scaling([Ellipsis], (2, 3), 2).eval() == 1
70+
71+
assert _get_scaling([4, 5, 9, Ellipsis, 32, 12], (2, 3, 2), 5).eval() == 960
72+
assert _get_scaling([4, 5, 9, Ellipsis], (2, 3, 2), 5).eval() == 15
73+
# total_size with no Ellipsis (end = [ ])
74+
with pytest.raises(ValueError):
75+
_get_scaling([1, 2, 5], (2, 3), 2)
76+
77+
assert _get_scaling([], (2, 3), 2).eval() == 1
78+
assert _get_scaling((), (2, 3), 2).eval() == 1
79+
# total_size invalid type
80+
with pytest.raises(TypeError):
81+
_get_scaling({1, 2, 5}, (2, 3), 2)
82+
83+
# test with rvar from model graph
84+
with Model() as m2:
85+
rv_var = Uniform("a", 0.0, 1.0)
86+
total_size = []
87+
assert _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim).eval() == 1.0
88+
89+
4690
def test_joint_logpt_basic():
4791
"""Make sure we can compute a log-likelihood for a hierarchical model with transforms."""
4892

0 commit comments

Comments
 (0)