33
33
from pymc .aesaraf import floatX , walk_model
34
34
from pymc .distributions .continuous import HalfFlat , Normal , TruncatedNormal , Uniform
35
35
from pymc .distributions .discrete import Bernoulli
36
- from pymc .distributions .logprob import ignore_logprob , joint_logpt , logcdf , logp
36
+ from pymc .distributions .logprob import (
37
+ _get_scaling ,
38
+ ignore_logprob ,
39
+ joint_logpt ,
40
+ logcdf ,
41
+ logp ,
42
+ )
37
43
from pymc .model import Model , Potential
38
44
from pymc .tests .helpers import select_by_precision
39
45
@@ -43,6 +49,44 @@ def assert_no_rvs(var):
43
49
return var
44
50
45
51
52
+ def test_get_scaling ():
53
+
54
+ assert _get_scaling (None , (2 , 3 ), 2 ).eval () == 1
55
+ # ndim >=1 & ndim<1
56
+ assert _get_scaling (45 , (2 , 3 ), 1 ).eval () == 22.5
57
+ assert _get_scaling (45 , (2 , 3 ), 0 ).eval () == 45
58
+
59
+ # list or tuple tests
60
+ # total_size contains other than Ellipsis, None and Int
61
+ with pytest .raises (TypeError ):
62
+ _get_scaling ([2 , 4 , 5 , 9 , 11.5 ], (2 , 3 ), 2 )
63
+ # check with Ellipsis
64
+ with pytest .raises (ValueError ):
65
+ _get_scaling ([1 , 2 , 5 , Ellipsis , Ellipsis ], (2 , 3 ), 2 )
66
+ with pytest .raises (ValueError ):
67
+ _get_scaling ([1 , 2 , 5 , Ellipsis ], (2 , 3 ), 2 )
68
+
69
+ assert _get_scaling ([Ellipsis ], (2 , 3 ), 2 ).eval () == 1
70
+
71
+ assert _get_scaling ([4 , 5 , 9 , Ellipsis , 32 , 12 ], (2 , 3 , 2 ), 5 ).eval () == 960
72
+ assert _get_scaling ([4 , 5 , 9 , Ellipsis ], (2 , 3 , 2 ), 5 ).eval () == 15
73
+ # total_size with no Ellipsis (end = [ ])
74
+ with pytest .raises (ValueError ):
75
+ _get_scaling ([1 , 2 , 5 ], (2 , 3 ), 2 )
76
+
77
+ assert _get_scaling ([], (2 , 3 ), 2 ).eval () == 1
78
+ assert _get_scaling ((), (2 , 3 ), 2 ).eval () == 1
79
+ # total_size invalid type
80
+ with pytest .raises (TypeError ):
81
+ _get_scaling ({1 , 2 , 5 }, (2 , 3 ), 2 )
82
+
83
+ # test with rvar from model graph
84
+ with Model () as m2 :
85
+ rv_var = Uniform ("a" , 0.0 , 1.0 )
86
+ total_size = []
87
+ assert _get_scaling (total_size , shape = rv_var .shape , ndim = rv_var .ndim ).eval () == 1.0
88
+
89
+
46
90
def test_joint_logpt_basic ():
47
91
"""Make sure we can compute a log-likelihood for a hierarchical model with transforms."""
48
92
0 commit comments