Skip to content

Commit 14d5421

Browse files
committed
black
1 parent 3cb4570 commit 14d5421

File tree

151 files changed

+11443
-7663
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

151 files changed

+11443
-7663
lines changed

pymc3/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55
from .distributions import *
66
from .glm import *
77
from . import gp
8-
from .math import logaddexp, logsumexp, logit, invlogit, expand_packed_triangular, probit, invprobit
8+
from .math import (
9+
logaddexp,
10+
logsumexp,
11+
logit,
12+
invlogit,
13+
expand_packed_triangular,
14+
probit,
15+
invprobit,
16+
)
917
from .model import *
1018
from .model_graph import model_to_graphviz
1119
from .stats import *
@@ -28,7 +36,8 @@
2836
from .data import *
2937

3038
import logging
31-
_log = logging.getLogger('pymc3')
39+
40+
_log = logging.getLogger("pymc3")
3241
if not logging.root.handlers:
3342
_log.setLevel(logging.INFO)
3443
handler = logging.StreamHandler()

pymc3/backends/__init__.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,8 @@
122122
from ..backends.sqlite import SQLite
123123
from ..backends.hdf5 import HDF5
124124

125-
_shortcuts = {'text': {'backend': Text,
126-
'name': 'mcmc'},
127-
'sqlite': {'backend': SQLite,
128-
'name': 'mcmc.sqlite'},
129-
'hdf5': {'backend': HDF5,
130-
'name': 'mcmc.hdf5'}}
125+
_shortcuts = {
126+
"text": {"backend": Text, "name": "mcmc"},
127+
"sqlite": {"backend": SQLite, "name": "mcmc.sqlite"},
128+
"hdf5": {"backend": HDF5, "name": "mcmc.hdf5"},
129+
}

pymc3/backends/base.py

Lines changed: 59 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ..model import modelcontext
1414
from .report import SamplerReport, merge_reports
1515

16-
logger = logging.getLogger('pymc3')
16+
logger = logging.getLogger("pymc3")
1717

1818

1919
class BackendError(Exception):
@@ -58,10 +58,8 @@ def __init__(self, name, model=None, vars=None, test_point=None):
5858
test_point_.update(test_point)
5959
test_point = test_point_
6060
var_values = list(zip(self.varnames, self.fn(test_point)))
61-
self.var_shapes = {var: value.shape
62-
for var, value in var_values}
63-
self.var_dtypes = {var: value.dtype
64-
for var, value in var_values}
61+
self.var_shapes = {var: value.shape for var, value in var_values}
62+
self.var_dtypes = {var: value.dtype for var, value in var_values}
6563
self.chain = None
6664
self._is_base_setup = False
6765
self.sampler_vars = None
@@ -87,8 +85,9 @@ def _set_sampler_vars(self, sampler_vars):
8785
for stats in sampler_vars:
8886
for key, dtype in stats.items():
8987
if dtypes.setdefault(key, dtype) != dtype:
90-
raise ValueError("Sampler statistic %s appears with "
91-
"different types." % key)
88+
raise ValueError(
89+
"Sampler statistic %s appears with " "different types." % key
90+
)
9291

9392
self.sampler_vars = sampler_vars
9493

@@ -137,7 +136,7 @@ def __getitem__(self, idx):
137136
try:
138137
return self.point(int(idx))
139138
except (ValueError, TypeError): # Passed variable or variable name.
140-
raise ValueError('Can only index with slice or integer')
139+
raise ValueError("Can only index with slice or integer")
141140

142141
def __len__(self):
143142
raise NotImplementedError
@@ -181,13 +180,14 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
181180
if sampler_idx is not None:
182181
return self._get_sampler_stats(varname, sampler_idx, burn, thin)
183182

184-
sampler_idxs = [i for i, s in enumerate(self.sampler_vars)
185-
if varname in s]
183+
sampler_idxs = [i for i, s in enumerate(self.sampler_vars) if varname in s]
186184
if not sampler_idxs:
187185
raise KeyError("Unknown sampler stat %s" % varname)
188186

189-
vals = np.stack([self._get_sampler_stats(varname, i, burn, thin)
190-
for i in sampler_idxs], axis=-1)
187+
vals = np.stack(
188+
[self._get_sampler_stats(varname, i, burn, thin) for i in sampler_idxs],
189+
axis=-1,
190+
)
191191
if vals.shape[-1] == 1:
192192
return vals[..., 0]
193193
else:
@@ -267,13 +267,14 @@ def __init__(self, straces):
267267

268268
self._report = SamplerReport()
269269
for strace in straces:
270-
if hasattr(strace, '_warnings'):
270+
if hasattr(strace, "_warnings"):
271271
self._report._add_warnings(strace._warnings, strace.chain)
272272

273273
def __repr__(self):
274-
template = '<{}: {} chains, {} iterations, {} variables>'
275-
return template.format(self.__class__.__name__,
276-
self.nchains, len(self), len(self.varnames))
274+
template = "<{}: {} chains, {} iterations, {} variables>"
275+
return template.format(
276+
self.__class__.__name__, self.nchains, len(self), len(self.varnames)
277+
)
277278

278279
@property
279280
def nchains(self):
@@ -310,16 +311,26 @@ def __getitem__(self, idx):
310311
var = str(var)
311312
if var in self.varnames:
312313
if var in self.stat_names:
313-
warnings.warn("Attribute access on a trace object is ambigous. "
314-
"Sampler statistic and model variable share a name. Use "
315-
"trace.get_values or trace.get_sampler_stats.")
314+
warnings.warn(
315+
"Attribute access on a trace object is ambigous. "
316+
"Sampler statistic and model variable share a name. Use "
317+
"trace.get_values or trace.get_sampler_stats."
318+
)
316319
return self.get_values(var, burn=burn, thin=thin)
317320
if var in self.stat_names:
318321
return self.get_sampler_stats(var, burn=burn, thin=thin)
319322
raise KeyError("Unknown variable %s" % var)
320323

321-
_attrs = set(['_straces', 'varnames', 'chains', 'stat_names',
322-
'supports_sampler_stats', '_report'])
324+
_attrs = set(
325+
[
326+
"_straces",
327+
"varnames",
328+
"chains",
329+
"stat_names",
330+
"supports_sampler_stats",
331+
"_report",
332+
]
333+
)
323334

324335
def __getattr__(self, name):
325336
# Avoid infinite recursion when called before __init__
@@ -330,14 +341,17 @@ def __getattr__(self, name):
330341
name = str(name)
331342
if name in self.varnames:
332343
if name in self.stat_names:
333-
warnings.warn("Attribute access on a trace object is ambigous. "
334-
"Sampler statistic and model variable share a name. Use "
335-
"trace.get_values or trace.get_sampler_stats.")
344+
warnings.warn(
345+
"Attribute access on a trace object is ambigous. "
346+
"Sampler statistic and model variable share a name. Use "
347+
"trace.get_values or trace.get_sampler_stats."
348+
)
336349
return self.get_values(name)
337350
if name in self.stat_names:
338351
return self.get_sampler_stats(name)
339-
raise AttributeError("'{}' object has no attribute '{}'".format(
340-
type(self).__name__, name))
352+
raise AttributeError(
353+
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
354+
)
341355

342356
def __len__(self):
343357
chain = self.chains[-1]
@@ -392,10 +406,12 @@ def add_values(self, vals, overwrite=False):
392406
l_samples = len(self) * len(self.chains)
393407
l_v = len(v)
394408
if l_v != l_samples:
395-
warnings.warn("The length of the values you are trying to "
396-
"add ({}) does not match the number ({}) of "
397-
"total samples in the trace "
398-
"(chains * iterations)".format(l_v, l_samples))
409+
warnings.warn(
410+
"The length of the values you are trying to "
411+
"add ({}) does not match the number ({}) of "
412+
"total samples in the trace "
413+
"(chains * iterations)".format(l_v, l_samples)
414+
)
399415

400416
v = np.squeeze(v.reshape(len(chains), len(self), -1))
401417

@@ -424,8 +440,9 @@ def remove_values(self, name):
424440
chain.vars.remove(va)
425441
del chain.samples[name]
426442

427-
def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
428-
squeeze=True):
443+
def get_values(
444+
self, varname, burn=0, thin=1, combine=True, chains=None, squeeze=True
445+
):
429446
"""Get values from traces.
430447
431448
Parameters
@@ -452,14 +469,16 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
452469
chains = self.chains
453470
varname = str(varname)
454471
try:
455-
results = [self._straces[chain].get_values(varname, burn, thin)
456-
for chain in chains]
472+
results = [
473+
self._straces[chain].get_values(varname, burn, thin) for chain in chains
474+
]
457475
except TypeError: # Single chain passed.
458476
results = [self._straces[chains].get_values(varname, burn, thin)]
459477
return _squeeze_cat(results, combine, squeeze)
460478

461-
def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
462-
chains=None, squeeze=True):
479+
def get_sampler_stats(
480+
self, varname, burn=0, thin=1, combine=True, chains=None, squeeze=True
481+
):
463482
"""Get sampler statistics from the trace.
464483
465484
Parameters
@@ -487,8 +506,10 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
487506
except TypeError:
488507
chains = [chains]
489508

490-
results = [self._straces[chain].get_sampler_stats(varname, None, burn, thin)
491-
for chain in chains]
509+
results = [
510+
self._straces[chain].get_sampler_stats(varname, None, burn, thin)
511+
for chain in chains
512+
]
492513
return _squeeze_cat(results, combine, squeeze)
493514

494515
def _slice(self, slice):

pymc3/backends/hdf5.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import h5py
33
from contextlib import contextmanager
44

5+
56
@contextmanager
67
def activator(instance):
78
if isinstance(instance.hdf5_file, h5py.File):
89
if instance.hdf5_file.id: # if file is open, keep open
910
yield
1011
return
1112
# if file is closed/not referenced: open, do job, then close
12-
instance.hdf5_file = h5py.File(instance.name, 'a')
13+
instance.hdf5_file = h5py.File(instance.name, "a")
1314
yield
1415
instance.hdf5_file.close()
1516
return
@@ -50,21 +51,21 @@ def activate_file(self):
5051
@property
5152
def samples(self):
5253
g = self.hdf5_file.require_group(str(self.chain))
53-
if 'name' not in g.attrs:
54-
g.attrs['name'] = self.chain
55-
return g.require_group('samples')
54+
if "name" not in g.attrs:
55+
g.attrs["name"] = self.chain
56+
return g.require_group("samples")
5657

5758
@property
5859
def stats(self):
5960
g = self.hdf5_file.require_group(str(self.chain))
60-
if 'name' not in g.attrs:
61-
g.attrs['name'] = self.chain
62-
return g.require_group('stats')
61+
if "name" not in g.attrs:
62+
g.attrs["name"] = self.chain
63+
return g.require_group("stats")
6364

6465
@property
6566
def chains(self):
6667
with self.activate_file:
67-
return [v.attrs['name'] for v in self.hdf5_file.values()]
68+
return [v.attrs["name"] for v in self.hdf5_file.values()]
6869

6970
@property
7071
def is_new_file(self):
@@ -84,19 +85,19 @@ def nchains(self):
8485
@property
8586
def records_stats(self):
8687
with self.activate_file:
87-
return self.hdf5_file.attrs['records_stats']
88+
return self.hdf5_file.attrs["records_stats"]
8889

8990
@records_stats.setter
9091
def records_stats(self, v):
9192
with self.activate_file:
92-
self.hdf5_file.attrs['records_stats'] = bool(v)
93+
self.hdf5_file.attrs["records_stats"] = bool(v)
9394

9495
def _resize(self, n):
9596
for v in self.samples.values():
9697
v.resize(n, axis=0)
9798
for key, group in self.stats.items():
9899
for statds in group.values():
99-
statds.resize((n, ))
100+
statds.resize((n,))
100101

101102
@property
102103
def sampler_vars(self):
@@ -123,10 +124,15 @@ def sampler_vars(self, values):
123124
if not data.keys(): # no pre-recorded stats
124125
for varname, dtype in sampler.items():
125126
if varname not in data:
126-
data.create_dataset(varname, (self.draws,), dtype=dtype, maxshape=(None,))
127+
data.create_dataset(
128+
varname, (self.draws,), dtype=dtype, maxshape=(None,)
129+
)
127130
elif data.keys() != sampler.keys():
128131
raise ValueError(
129-
"Sampler vars can't change, names incompatible: {} != {}".format(data.keys(), sampler.keys()))
132+
"Sampler vars can't change, names incompatible: {} != {}".format(
133+
data.keys(), sampler.keys()
134+
)
135+
)
130136
self.records_stats = True
131137

132138
def setup(self, draws, chain, sampler_vars=None):
@@ -146,16 +152,18 @@ def setup(self, draws, chain, sampler_vars=None):
146152
with self.activate_file:
147153
for varname, shape in self.var_shapes.items():
148154
if varname not in self.samples:
149-
self.samples.create_dataset(name=varname, shape=(draws, ) + shape,
150-
dtype=self.var_dtypes[varname],
151-
maxshape=(None, ) + shape)
155+
self.samples.create_dataset(
156+
name=varname,
157+
shape=(draws,) + shape,
158+
dtype=self.var_dtypes[varname],
159+
maxshape=(None,) + shape,
160+
)
152161
self.draw_idx = len(self)
153162
self.draws = self.draw_idx + draws
154163
self._set_sampler_vars(sampler_vars)
155164
self._is_base_setup = True
156165
self._resize(self.draws)
157166

158-
159167
def close(self):
160168
with self.activate_file:
161169
if self.draw_idx == self.draws:
@@ -190,8 +198,9 @@ def _slice(self, idx):
190198
start, stop, step = idx.indices(len(self))
191199
sliced = ndarray.NDArray(model=self.model, vars=self.vars)
192200
sliced.chain = self.chain
193-
sliced.samples = {v: self.samples[v][start:stop:step]
194-
for v in self.varnames}
201+
sliced.samples = {
202+
v: self.samples[v][start:stop:step] for v in self.varnames
203+
}
195204
sliced.draw_idx = (stop - start) // step
196205
return sliced
197206

0 commit comments

Comments
 (0)