Skip to content

Commit 0bd2d65

Browse files
aliakbarsMarcoGorelliAli Septiandri
authored
Run black on distributions/ and */__init__.py (#4139)
* Run black on distributions/ and */__init__.py * Run black on step_methods/ * Revert two files * Update pymc3/distributions/bound.py Co-authored-by: Marco Gorelli <[email protected]> * Revert blackened metropolis.py * Revert blackened files * Fix error in quadpotential.py Co-authored-by: Marco Gorelli <[email protected]> Co-authored-by: Ali Septiandri <[email protected]>
1 parent 77fe82a commit 0bd2d65

21 files changed

+347
-415
lines changed

pymc3/distributions/__init__.py

Lines changed: 76 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -100,78 +100,79 @@
100100

101101
from .bound import Bound
102102

103-
__all__ = ['Uniform',
104-
'Flat',
105-
'HalfFlat',
106-
'TruncatedNormal',
107-
'Normal',
108-
'Beta',
109-
'Kumaraswamy',
110-
'Exponential',
111-
'Laplace',
112-
'StudentT',
113-
'Cauchy',
114-
'HalfCauchy',
115-
'Gamma',
116-
'Weibull',
117-
'Bound',
118-
'Lognormal',
119-
'HalfStudentT',
120-
'ChiSquared',
121-
'HalfNormal',
122-
'Wald',
123-
'Pareto',
124-
'InverseGamma',
125-
'ExGaussian',
126-
'VonMises',
127-
'Binomial',
128-
'BetaBinomial',
129-
'Bernoulli',
130-
'Poisson',
131-
'NegativeBinomial',
132-
'ConstantDist',
133-
'Constant',
134-
'ZeroInflatedPoisson',
135-
'ZeroInflatedNegativeBinomial',
136-
'ZeroInflatedBinomial',
137-
'DiscreteUniform',
138-
'Geometric',
139-
'Categorical',
140-
'OrderedLogistic',
141-
'DensityDist',
142-
'Distribution',
143-
'Continuous',
144-
'Discrete',
145-
'NoDistribution',
146-
'TensorType',
147-
'MvNormal',
148-
'MatrixNormal',
149-
'KroneckerNormal',
150-
'MvStudentT',
151-
'Dirichlet',
152-
'Multinomial',
153-
'Wishart',
154-
'WishartBartlett',
155-
'LKJCholeskyCov',
156-
'LKJCorr',
157-
'AR1',
158-
'AR',
159-
'GaussianRandomWalk',
160-
'MvGaussianRandomWalk',
161-
'MvStudentTRandomWalk',
162-
'GARCH11',
163-
'SkewNormal',
164-
'Mixture',
165-
'NormalMixture',
166-
'Triangular',
167-
'DiscreteWeibull',
168-
'Gumbel',
169-
'Logistic',
170-
'LogitNormal',
171-
'Interpolated',
172-
'Bound',
173-
'Rice',
174-
'Moyal',
175-
'Simulator',
176-
'fast_sample_posterior_predictive'
177-
]
103+
__all__ = [
104+
"Uniform",
105+
"Flat",
106+
"HalfFlat",
107+
"TruncatedNormal",
108+
"Normal",
109+
"Beta",
110+
"Kumaraswamy",
111+
"Exponential",
112+
"Laplace",
113+
"StudentT",
114+
"Cauchy",
115+
"HalfCauchy",
116+
"Gamma",
117+
"Weibull",
118+
"Bound",
119+
"Lognormal",
120+
"HalfStudentT",
121+
"ChiSquared",
122+
"HalfNormal",
123+
"Wald",
124+
"Pareto",
125+
"InverseGamma",
126+
"ExGaussian",
127+
"VonMises",
128+
"Binomial",
129+
"BetaBinomial",
130+
"Bernoulli",
131+
"Poisson",
132+
"NegativeBinomial",
133+
"ConstantDist",
134+
"Constant",
135+
"ZeroInflatedPoisson",
136+
"ZeroInflatedNegativeBinomial",
137+
"ZeroInflatedBinomial",
138+
"DiscreteUniform",
139+
"Geometric",
140+
"Categorical",
141+
"OrderedLogistic",
142+
"DensityDist",
143+
"Distribution",
144+
"Continuous",
145+
"Discrete",
146+
"NoDistribution",
147+
"TensorType",
148+
"MvNormal",
149+
"MatrixNormal",
150+
"KroneckerNormal",
151+
"MvStudentT",
152+
"Dirichlet",
153+
"Multinomial",
154+
"Wishart",
155+
"WishartBartlett",
156+
"LKJCholeskyCov",
157+
"LKJCorr",
158+
"AR1",
159+
"AR",
160+
"GaussianRandomWalk",
161+
"MvGaussianRandomWalk",
162+
"MvStudentTRandomWalk",
163+
"GARCH11",
164+
"SkewNormal",
165+
"Mixture",
166+
"NormalMixture",
167+
"Triangular",
168+
"DiscreteWeibull",
169+
"Gumbel",
170+
"Logistic",
171+
"LogitNormal",
172+
"Interpolated",
173+
"Bound",
174+
"Rice",
175+
"Moyal",
176+
"Simulator",
177+
"fast_sample_posterior_predictive",
178+
]

pymc3/distributions/bound.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,13 @@ def _random(self, lower, upper, point=None, size=None):
8282
upper = np.asarray(upper)
8383
if lower.size > 1 or upper.size > 1:
8484
raise ValueError(
85-
"Drawing samples from distributions with "
86-
"array-valued bounds is not supported."
85+
"Drawing samples from distributions with array-valued bounds is not supported."
8786
)
8887
total_size = np.prod(size).astype(np.int)
8988
samples = []
9089
s = 0
9190
while s < total_size:
92-
sample = np.atleast_1d(
93-
self._wrapped.random(point=point, size=total_size)
94-
).flatten()
91+
sample = np.atleast_1d(self._wrapped.random(point=point, size=total_size)).flatten()
9592

9693
select = sample[np.logical_and(sample >= lower, sample <= upper)]
9794
samples.append(select)
@@ -128,7 +125,7 @@ def random(self, point=None, size=None):
128125
upper,
129126
dist_shape=self.shape,
130127
size=size,
131-
not_broadcast_kwargs={'point': point},
128+
not_broadcast_kwargs={"point": point},
132129
)
133130
elif self.lower is not None:
134131
lower = draw_values([self.lower], point=point, size=size)
@@ -138,7 +135,7 @@ def random(self, point=None, size=None):
138135
np.inf,
139136
dist_shape=self.shape,
140137
size=size,
141-
not_broadcast_kwargs={'point': point},
138+
not_broadcast_kwargs={"point": point},
142139
)
143140
else:
144141
upper = draw_values([self.upper], point=point, size=size)
@@ -148,7 +145,7 @@ def random(self, point=None, size=None):
148145
upper,
149146
dist_shape=self.shape,
150147
size=size,
151-
not_broadcast_kwargs={'point': point},
148+
not_broadcast_kwargs={"point": point},
152149
)
153150

154151

@@ -168,9 +165,7 @@ def __init__(self, distribution, lower, upper, transform="infer", *args, **kwarg
168165
if lower is not None:
169166
default = lower + 1
170167

171-
super().__init__(
172-
distribution, lower, upper, default, *args, transform=transform, **kwargs
173-
)
168+
super().__init__(distribution, lower, upper, default, *args, transform=transform, **kwargs)
174169

175170

176171
class _ContinuousBounded(_Bounded, Continuous):
@@ -215,9 +210,7 @@ def __init__(self, distribution, lower, upper, transform="infer", *args, **kwarg
215210
else:
216211
default = None
217212

218-
super().__init__(
219-
distribution, lower, upper, default, *args, transform=transform, **kwargs
220-
)
213+
super().__init__(distribution, lower, upper, default, *args, transform=transform, **kwargs)
221214

222215

223216
class Bound:
@@ -283,23 +276,11 @@ def __call__(self, name, *args, **kwargs):
283276
transform = kwargs.pop("transform", "infer")
284277
if issubclass(self.distribution, Continuous):
285278
return _ContinuousBounded(
286-
name,
287-
self.distribution,
288-
self.lower,
289-
self.upper,
290-
transform,
291-
*args,
292-
**kwargs
279+
name, self.distribution, self.lower, self.upper, transform, *args, **kwargs
293280
)
294281
elif issubclass(self.distribution, Discrete):
295282
return _DiscreteBounded(
296-
name,
297-
self.distribution,
298-
self.lower,
299-
self.upper,
300-
transform,
301-
*args,
302-
**kwargs
283+
name, self.distribution, self.lower, self.upper, transform, *args, **kwargs
303284
)
304285
else:
305286
raise ValueError("Distribution is neither continuous nor discrete.")
@@ -311,8 +292,6 @@ def dist(self, *args, **kwargs):
311292
)
312293

313294
elif issubclass(self.distribution, Discrete):
314-
return _DiscreteBounded.dist(
315-
self.distribution, self.lower, self.upper, *args, **kwargs
316-
)
295+
return _DiscreteBounded.dist(self.distribution, self.lower, self.upper, *args, **kwargs)
317296
else:
318297
raise ValueError("Distribution is neither continuous nor discrete.")

pymc3/distributions/shape_utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ def _check_shape_type(shape):
6464
raise TypeError(f"Value {s} is not a valid integer")
6565
out.append(o)
6666
except Exception:
67-
raise TypeError(
68-
f"Supplied value {shape} does not represent a valid shape"
69-
)
67+
raise TypeError(f"Supplied value {shape} does not represent a valid shape")
7068
return tuple(out)
7169

7270

@@ -172,10 +170,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
172170
shapes = [_check_shape_type(s) for s in shapes]
173171
_size = to_tuple(size)
174172
# samples shapes without the size prepend
175-
sp_shapes = [
176-
s[len(_size) :] if _size == s[: min([len(_size), len(s)])] else s
177-
for s in shapes
178-
]
173+
sp_shapes = [s[len(_size) :] if _size == s[: min([len(_size), len(s)])] else s for s in shapes]
179174
try:
180175
broadcast_shape = shapes_broadcasting(*sp_shapes, raise_exception=True)
181176
except ValueError:
@@ -277,8 +272,7 @@ def get_broadcastable_dist_samples(
277272
out_shape = broadcast_dist_samples_shape(p_shapes, size=size)
278273
# samples shapes without the size prepend
279274
sp_shapes = [
280-
s[len(_size) :] if _size == s[: min([len(_size), len(s)])] else s
281-
for s in p_shapes
275+
s[len(_size) :] if _size == s[: min([len(_size), len(s)])] else s for s in p_shapes
282276
]
283277
broadcast_shape = shapes_broadcasting(*sp_shapes, raise_exception=True)
284278
broadcastable_samples = []

pymc3/distributions/simulator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
130130
return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})"
131131

132132

133-
134133
def identity(x):
135134
"""Identity function, used as a summary statistics."""
136135
return x

pymc3/distributions/special.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from theano.scalar.basic_scipy import GammaLn, Psi
1818
from theano import scalar
1919

20-
__all__ = ['gammaln', 'multigammaln', 'psi', 'log_i0']
20+
__all__ = ["gammaln", "multigammaln", "psi", "log_i0"]
2121

22-
scalar_gammaln = GammaLn(scalar.upgrade_to_float, name='scalar_gammaln')
23-
gammaln = tt.Elemwise(scalar_gammaln, name='gammaln')
22+
scalar_gammaln = GammaLn(scalar.upgrade_to_float, name="scalar_gammaln")
23+
gammaln = tt.Elemwise(scalar_gammaln, name="gammaln")
2424

2525

2626
def multigammaln(a, p):
@@ -33,21 +33,33 @@ def multigammaln(a, p):
3333
degrees of freedom. p > 0
3434
"""
3535
i = tt.arange(1, p + 1)
36-
return (p * (p - 1) * tt.log(np.pi) / 4.
37-
+ tt.sum(gammaln(a + (1. - i) / 2.), axis=0))
36+
return p * (p - 1) * tt.log(np.pi) / 4.0 + tt.sum(gammaln(a + (1.0 - i) / 2.0), axis=0)
3837

3938

4039
def log_i0(x):
4140
"""
4241
Calculates the logarithm of the 0 order modified Bessel function of the first kind""
4342
"""
44-
return tt.switch(tt.lt(x, 5), tt.log1p(x**2. / 4. + x**4. / 64. + x**6. / 2304.
45-
+ x**8. / 147456. + x**10. / 14745600.
46-
+ x**12. / 2123366400.),
47-
x - 0.5 * tt.log(2. * np.pi * x) + tt.log1p(1. / (8. * x)
48-
+ 9. / (128. * x**2.) + 225. / (3072. * x**3.)
49-
+ 11025. / (98304. * x**4.)))
43+
return tt.switch(
44+
tt.lt(x, 5),
45+
tt.log1p(
46+
x ** 2.0 / 4.0
47+
+ x ** 4.0 / 64.0
48+
+ x ** 6.0 / 2304.0
49+
+ x ** 8.0 / 147456.0
50+
+ x ** 10.0 / 14745600.0
51+
+ x ** 12.0 / 2123366400.0
52+
),
53+
x
54+
- 0.5 * tt.log(2.0 * np.pi * x)
55+
+ tt.log1p(
56+
1.0 / (8.0 * x)
57+
+ 9.0 / (128.0 * x ** 2.0)
58+
+ 225.0 / (3072.0 * x ** 3.0)
59+
+ 11025.0 / (98304.0 * x ** 4.0)
60+
),
61+
)
5062

5163

52-
scalar_psi = Psi(scalar.upgrade_to_float, name='scalar_psi')
53-
psi = tt.Elemwise(scalar_psi, name='psi')
64+
scalar_psi = Psi(scalar.upgrade_to_float, name="scalar_psi")
65+
psi = tt.Elemwise(scalar_psi, name="psi")

0 commit comments

Comments
 (0)