@@ -1127,31 +1127,31 @@ def test_cosh_rv_transform():
1127
1127
)
1128
1128
1129
1129
1130
- TRANSFORMATIONS = {
1131
- "log1p" : (pt .log1p , lambda x : pt .log (1 + x )),
1132
- "softplus" : (pt .softplus , lambda x : pt .log (1 + pt .exp (x ))),
1133
- "log1mexp" : (pt .log1mexp , lambda x : pt .log (1 - pt .exp (x ))),
1134
- "log2" : (pt .log2 , lambda x : pt .log (x ) / pt .log (2 )),
1135
- "log10" : (pt .log10 , lambda x : pt .log (x ) / pt .log (10 )),
1136
- "exp2" : (pt .exp2 , lambda x : pt .exp (pt .log (2 ) * x )),
1137
- "expm1" : (pt .expm1 , lambda x : pt .exp (x ) - 1 ),
1138
- "sigmoid" : (pt .sigmoid , lambda x : 1 / (1 + pt .exp (- x ))),
1139
- }
1140
-
1141
-
1142
- @pytest .mark .parametrize ("transform" , TRANSFORMATIONS .keys ())
1143
- def test_special_log_exp_transforms (transform ):
1130
+ @pytest .mark .parametrize (
1131
+ "canonical_func,raw_func" ,
1132
+ [
1133
+ (pt .log1p , lambda x : pt .log (1 + x )),
1134
+ (pt .softplus , lambda x : pt .log (1 + pt .exp (x ))),
1135
+ (pt .log1mexp , lambda x : pt .log (1 - pt .exp (x ))),
1136
+ (pt .log2 , lambda x : pt .log (x ) / pt .log (2 )),
1137
+ (pt .log10 , lambda x : pt .log (x ) / pt .log (10 )),
1138
+ (pt .exp2 , lambda x : pt .exp (pt .log (2 ) * x )),
1139
+ (pt .expm1 , lambda x : pt .exp (x ) - 1 ),
1140
+ (pt .sigmoid , lambda x : 1 / (1 + pt .exp (- x ))),
1141
+ (pt .sigmoid , lambda x : pt .exp (x ) / (1 + pt .exp (x ))),
1142
+ ],
1143
+ )
1144
+ def test_special_log_exp_transforms (canonical_func , raw_func ):
1144
1145
base_rv = pt .random .normal (name = "base_rv" )
1145
1146
vv = pt .scalar ("vv" )
1146
1147
1147
- transform_func , ref_func = TRANSFORMATIONS [transform ]
1148
- transformed_rv = transform_func (base_rv )
1149
- ref_transformed_rv = ref_func (base_rv )
1148
+ transformed_rv = raw_func (base_rv )
1149
+ ref_transformed_rv = canonical_func (base_rv )
1150
1150
1151
1151
logp_test = logp (transformed_rv , vv )
1152
1152
logp_ref = logp (ref_transformed_rv , vv )
1153
1153
1154
- if transform in [ " log2" , " log10" ] :
1154
+ if canonical_func in ( pt . log2 , pt . log10 ) :
1155
1155
# in the cases of log2 and log10 floating point inprecision causes failure
1156
1156
# from equal_computations so evaluate logp and check all close instead
1157
1157
vv_test = np .array (0.25 )
0 commit comments