Skip to content

Commit 890ae74

Browse files
committed
Cleaned up model.py, made it comply with pep8, and fixed lint error on distribution.py.
1 parent 339828d commit 890ae74

File tree

2 files changed

+74
-103
lines changed

2 files changed

+74
-103
lines changed

pymc3/distributions/distribution.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import theano
66
from ..memoize import memoize
77
from ..model import (
8-
Model, modelcontext, FreeRV, ObservedRV, MultiObservedRV,
8+
Model, modelcontext, FreeRV, ObservedRV,
99
not_shared_or_constant_variable, DependenceDAG
1010
)
1111
from ..vartypes import string_types
@@ -35,12 +35,14 @@ def __new__(cls, name, *args, **kwargs):
3535
if isinstance(name, string_types):
3636
data = kwargs.pop('observed', None)
3737
if isinstance(data, ObservedRV) or isinstance(data, FreeRV):
38-
raise TypeError("observed needs to be data but got: {}".format(type(data)))
38+
raise TypeError("observed needs to be data but got: {}".
39+
format(type(data)))
3940
total_size = kwargs.pop('total_size', None)
4041
dist = cls.dist(*args, **kwargs)
4142
return model.Var(name, dist, data, total_size)
4243
else:
43-
raise TypeError("Name needs to be a string but got: {}".format(name))
44+
raise TypeError("Name needs to be a string but got: {}".
45+
format(name))
4446

4547
def __getnewargs__(self):
4648
return _Unpickling,
@@ -64,12 +66,14 @@ def __init__(self, shape, dtype, testval=None, defaults=(),
6466
self.conditional_on = None
6567

6668
def default(self):
67-
return np.asarray(self.get_test_val(self.testval, self.defaults), self.dtype)
69+
return np.asarray(self.get_test_val(self.testval, self.defaults),
70+
self.dtype)
6871

6972
def get_test_val(self, val, defaults):
7073
if val is None:
7174
for v in defaults:
72-
if hasattr(self, v) and np.all(np.isfinite(self.getattr_value(v))):
75+
if (hasattr(self, v) and
76+
np.all(np.isfinite(self.getattr_value(v)))):
7377
return self.getattr_value(v)
7478
else:
7579
return self.getattr_value(val)
@@ -132,7 +136,8 @@ class NoDistribution(Distribution):
132136
def __init__(self, shape, dtype, testval=None, defaults=(),
133137
transform=None, parent_dist=None, *args, **kwargs):
134138
super(NoDistribution, self).__init__(shape=shape, dtype=dtype,
135-
testval=testval, defaults=defaults,
139+
testval=testval,
140+
defaults=defaults,
136141
*args, **kwargs)
137142
self.parent_dist = parent_dist
138143

@@ -161,7 +166,8 @@ def __init__(self, shape=(), dtype=None, defaults=('mode',),
161166
else:
162167
dtype = 'int64'
163168
if dtype != 'int16' and dtype != 'int64':
164-
raise TypeError('Discrete classes expect dtype to be int16 or int64.')
169+
raise TypeError('Discrete classes expect dtype to be int16 or '
170+
'int64.')
165171

166172
if kwargs.get('transform', None) is not None:
167173
raise ValueError("Transformations for discrete distributions "
@@ -174,7 +180,8 @@ def __init__(self, shape=(), dtype=None, defaults=('mode',),
174180
class Continuous(Distribution):
175181
"""Base class for continuous distributions"""
176182

177-
def __init__(self, shape=(), dtype=None, defaults=('median', 'mean', 'mode'),
183+
def __init__(self, shape=(), dtype=None,
184+
defaults=('median', 'mean', 'mode'),
178185
*args, **kwargs):
179186
if dtype is None:
180187
dtype = theano.config.floatX
@@ -195,12 +202,15 @@ class DensityDist(Distribution):
195202
with pm.Model():
196203
mu = pm.Normal('mu',0,1)
197204
normal_dist = pm.Normal.dist(mu, 1)
198-
pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100), random=normal_dist.random)
205+
pm.DensityDist('density_dist', normal_dist.logp,
206+
observed=np.random.randn(100),
207+
random=normal_dist.random)
199208
trace = pm.sample(100)
200209
201210
"""
202211

203-
def __init__(self, logp, shape=(), dtype=None, testval=0, random=None, *args, **kwargs):
212+
def __init__(self, logp, shape=(), dtype=None, testval=0, random=None,
213+
*args, **kwargs):
204214
if dtype is None:
205215
dtype = theano.config.floatX
206216
super(DensityDist, self).__init__(
@@ -213,7 +223,8 @@ def random(self, *args, **kwargs):
213223
return self.rand(*args, **kwargs)
214224
else:
215225
raise ValueError("Distribution was not passed any random method "
216-
"Define a custom random method and pass it as kwarg random")
226+
"Define a custom random method and pass it as "
227+
"kwarg random")
217228

218229

219230
def draw_values(params, point=None, size=None, model=None):
@@ -462,6 +473,7 @@ def to_tuple(shape):
462473
shape = tuple(shape)
463474
return shape
464475

476+
465477
def _is_one_d(dist_shape):
466478
if hasattr(dist_shape, 'dshape') and dist_shape.dshape in ((), (0,), (1,)):
467479
return True
@@ -471,6 +483,7 @@ def _is_one_d(dist_shape):
471483
return True
472484
return False
473485

486+
474487
def generate_samples(generator, *args, **kwargs):
475488
"""Generate samples from the distribution of a random variable.
476489

0 commit comments

Comments
 (0)