33
33
from pymc .aesaraf import floatX , walk_model
34
34
from pymc .distributions .continuous import HalfFlat , Normal , TruncatedNormal , Uniform
35
35
from pymc .distributions .discrete import Bernoulli
36
- from pymc .distributions .logprob import ignore_logprob , joint_logpt , logcdf , logp
36
+ from pymc .distributions .logprob import (
37
+ _get_scaling ,
38
+ ignore_logprob ,
39
+ joint_logpt ,
40
+ logcdf ,
41
+ logp ,
42
+ )
37
43
from pymc .model import Model , Potential
38
44
from pymc .tests .helpers import select_by_precision
39
45
@@ -43,6 +49,53 @@ def assert_no_rvs(var):
43
49
return var
44
50
45
51
52
+ def test_get_scaling ():
53
+
54
+ assert _get_scaling (None , (2 , 3 ), 2 ).eval () == 1
55
+ # ndim >=1 & ndim<1
56
+ assert _get_scaling (45 , (2 , 3 ), 1 ).eval () == 22.5
57
+ assert _get_scaling (45 , (2 , 3 ), 0 ).eval () == 45
58
+
59
+ # list or tuple tests
60
+ # total_size contains other than Ellipsis, None and Int
61
+ with pytest .raises (TypeError , match = "Unrecognized `total_size` type" ):
62
+ _get_scaling ([2 , 4 , 5 , 9 , 11.5 ], (2 , 3 ), 2 )
63
+ # check with Ellipsis
64
+ with pytest .raises (ValueError , match = "Double Ellipsis in `total_size` is restricted" ):
65
+ _get_scaling ([1 , 2 , 5 , Ellipsis , Ellipsis ], (2 , 3 ), 2 )
66
+ with pytest .raises (
67
+ ValueError ,
68
+ match = "Length of `total_size` is too big, number of scalings is bigger that ndim" ,
69
+ ):
70
+ _get_scaling ([1 , 2 , 5 , Ellipsis ], (2 , 3 ), 2 )
71
+
72
+ assert _get_scaling ([Ellipsis ], (2 , 3 ), 2 ).eval () == 1
73
+
74
+ assert _get_scaling ([4 , 5 , 9 , Ellipsis , 32 , 12 ], (2 , 3 , 2 ), 5 ).eval () == 960
75
+ assert _get_scaling ([4 , 5 , 9 , Ellipsis ], (2 , 3 , 2 ), 5 ).eval () == 15
76
+ # total_size with no Ellipsis (end = [ ])
77
+ with pytest .raises (
78
+ ValueError ,
79
+ match = "Length of `total_size` is too big, number of scalings is bigger that ndim" ,
80
+ ):
81
+ _get_scaling ([1 , 2 , 5 ], (2 , 3 ), 2 )
82
+
83
+ assert _get_scaling ([], (2 , 3 ), 2 ).eval () == 1
84
+ assert _get_scaling ((), (2 , 3 ), 2 ).eval () == 1
85
+ # total_size invalid type
86
+ with pytest .raises (
87
+ TypeError ,
88
+ match = "Unrecognized `total_size` type, expected int or list of ints, got {1, 2, 5}" ,
89
+ ):
90
+ _get_scaling ({1 , 2 , 5 }, (2 , 3 ), 2 )
91
+
92
+ # test with rvar from model graph
93
+ with Model () as m2 :
94
+ rv_var = Uniform ("a" , 0.0 , 1.0 )
95
+ total_size = []
96
+ assert _get_scaling (total_size , shape = rv_var .shape , ndim = rv_var .ndim ).eval () == 1.0
97
+
98
+
46
99
def test_joint_logpt_basic ():
47
100
"""Make sure we can compute a log-likelihood for a hierarchical model with transforms."""
48
101
0 commit comments