Skip to content

black codestyle #3239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
from .distributions import *
from .glm import *
from . import gp
from .math import logaddexp, logsumexp, logit, invlogit, expand_packed_triangular, probit, invprobit
from .math import (
logaddexp,
logsumexp,
logit,
invlogit,
expand_packed_triangular,
probit,
invprobit,
)
from .model import *
from .model_graph import model_to_graphviz
from .stats import *
Expand All @@ -28,7 +36,8 @@
from .data import *

import logging
_log = logging.getLogger('pymc3')

_log = logging.getLogger("pymc3")
if not logging.root.handlers:
_log.setLevel(logging.INFO)
handler = logging.StreamHandler()
Expand Down
11 changes: 5 additions & 6 deletions pymc3/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,8 @@
from ..backends.sqlite import SQLite
from ..backends.hdf5 import HDF5

_shortcuts = {'text': {'backend': Text,
'name': 'mcmc'},
'sqlite': {'backend': SQLite,
'name': 'mcmc.sqlite'},
'hdf5': {'backend': HDF5,
'name': 'mcmc.hdf5'}}
_shortcuts = {
"text": {"backend": Text, "name": "mcmc"},
"sqlite": {"backend": SQLite, "name": "mcmc.sqlite"},
"hdf5": {"backend": HDF5, "name": "mcmc.hdf5"},
}
97 changes: 59 additions & 38 deletions pymc3/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..model import modelcontext
from .report import SamplerReport, merge_reports

logger = logging.getLogger('pymc3')
logger = logging.getLogger("pymc3")


class BackendError(Exception):
Expand Down Expand Up @@ -58,10 +58,8 @@ def __init__(self, name, model=None, vars=None, test_point=None):
test_point_.update(test_point)
test_point = test_point_
var_values = list(zip(self.varnames, self.fn(test_point)))
self.var_shapes = {var: value.shape
for var, value in var_values}
self.var_dtypes = {var: value.dtype
for var, value in var_values}
self.var_shapes = {var: value.shape for var, value in var_values}
self.var_dtypes = {var: value.dtype for var, value in var_values}
self.chain = None
self._is_base_setup = False
self.sampler_vars = None
Expand All @@ -87,8 +85,9 @@ def _set_sampler_vars(self, sampler_vars):
for stats in sampler_vars:
for key, dtype in stats.items():
if dtypes.setdefault(key, dtype) != dtype:
raise ValueError("Sampler statistic %s appears with "
"different types." % key)
raise ValueError(
"Sampler statistic %s appears with " "different types." % key
)

self.sampler_vars = sampler_vars

Expand Down Expand Up @@ -137,7 +136,7 @@ def __getitem__(self, idx):
try:
return self.point(int(idx))
except (ValueError, TypeError): # Passed variable or variable name.
raise ValueError('Can only index with slice or integer')
raise ValueError("Can only index with slice or integer")

def __len__(self):
raise NotImplementedError
Expand Down Expand Up @@ -181,13 +180,14 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
if sampler_idx is not None:
return self._get_sampler_stats(varname, sampler_idx, burn, thin)

sampler_idxs = [i for i, s in enumerate(self.sampler_vars)
if varname in s]
sampler_idxs = [i for i, s in enumerate(self.sampler_vars) if varname in s]
if not sampler_idxs:
raise KeyError("Unknown sampler stat %s" % varname)

vals = np.stack([self._get_sampler_stats(varname, i, burn, thin)
for i in sampler_idxs], axis=-1)
vals = np.stack(
[self._get_sampler_stats(varname, i, burn, thin) for i in sampler_idxs],
axis=-1,
)
if vals.shape[-1] == 1:
return vals[..., 0]
else:
Expand Down Expand Up @@ -267,13 +267,14 @@ def __init__(self, straces):

self._report = SamplerReport()
for strace in straces:
if hasattr(strace, '_warnings'):
if hasattr(strace, "_warnings"):
self._report._add_warnings(strace._warnings, strace.chain)

def __repr__(self):
template = '<{}: {} chains, {} iterations, {} variables>'
return template.format(self.__class__.__name__,
self.nchains, len(self), len(self.varnames))
template = "<{}: {} chains, {} iterations, {} variables>"
return template.format(
self.__class__.__name__, self.nchains, len(self), len(self.varnames)
)

@property
def nchains(self):
Expand Down Expand Up @@ -310,16 +311,26 @@ def __getitem__(self, idx):
var = str(var)
if var in self.varnames:
if var in self.stat_names:
warnings.warn("Attribute access on a trace object is ambigous. "
"Sampler statistic and model variable share a name. Use "
"trace.get_values or trace.get_sampler_stats.")
warnings.warn(
"Attribute access on a trace object is ambigous. "
"Sampler statistic and model variable share a name. Use "
"trace.get_values or trace.get_sampler_stats."
)
return self.get_values(var, burn=burn, thin=thin)
if var in self.stat_names:
return self.get_sampler_stats(var, burn=burn, thin=thin)
raise KeyError("Unknown variable %s" % var)

_attrs = set(['_straces', 'varnames', 'chains', 'stat_names',
'supports_sampler_stats', '_report'])
_attrs = set(
[
"_straces",
"varnames",
"chains",
"stat_names",
"supports_sampler_stats",
"_report",
]
)

def __getattr__(self, name):
# Avoid infinite recursion when called before __init__
Expand All @@ -330,14 +341,17 @@ def __getattr__(self, name):
name = str(name)
if name in self.varnames:
if name in self.stat_names:
warnings.warn("Attribute access on a trace object is ambigous. "
"Sampler statistic and model variable share a name. Use "
"trace.get_values or trace.get_sampler_stats.")
warnings.warn(
"Attribute access on a trace object is ambigous. "
"Sampler statistic and model variable share a name. Use "
"trace.get_values or trace.get_sampler_stats."
)
return self.get_values(name)
if name in self.stat_names:
return self.get_sampler_stats(name)
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
)

def __len__(self):
chain = self.chains[-1]
Expand Down Expand Up @@ -392,10 +406,12 @@ def add_values(self, vals, overwrite=False):
l_samples = len(self) * len(self.chains)
l_v = len(v)
if l_v != l_samples:
warnings.warn("The length of the values you are trying to "
"add ({}) does not match the number ({}) of "
"total samples in the trace "
"(chains * iterations)".format(l_v, l_samples))
warnings.warn(
"The length of the values you are trying to "
"add ({}) does not match the number ({}) of "
"total samples in the trace "
"(chains * iterations)".format(l_v, l_samples)
)

v = np.squeeze(v.reshape(len(chains), len(self), -1))

Expand Down Expand Up @@ -424,8 +440,9 @@ def remove_values(self, name):
chain.vars.remove(va)
del chain.samples[name]

def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
squeeze=True):
def get_values(
self, varname, burn=0, thin=1, combine=True, chains=None, squeeze=True
):
"""Get values from traces.

Parameters
Expand All @@ -452,14 +469,16 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
chains = self.chains
varname = str(varname)
try:
results = [self._straces[chain].get_values(varname, burn, thin)
for chain in chains]
results = [
self._straces[chain].get_values(varname, burn, thin) for chain in chains
]
except TypeError: # Single chain passed.
results = [self._straces[chains].get_values(varname, burn, thin)]
return _squeeze_cat(results, combine, squeeze)

def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
chains=None, squeeze=True):
def get_sampler_stats(
self, varname, burn=0, thin=1, combine=True, chains=None, squeeze=True
):
"""Get sampler statistics from the trace.

Parameters
Expand Down Expand Up @@ -487,8 +506,10 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
except TypeError:
chains = [chains]

results = [self._straces[chain].get_sampler_stats(varname, None, burn, thin)
for chain in chains]
results = [
self._straces[chain].get_sampler_stats(varname, None, burn, thin)
for chain in chains
]
return _squeeze_cat(results, combine, squeeze)

def _slice(self, slice):
Expand Down
47 changes: 28 additions & 19 deletions pymc3/backends/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import h5py
from contextlib import contextmanager


@contextmanager
def activator(instance):
if isinstance(instance.hdf5_file, h5py.File):
if instance.hdf5_file.id: # if file is open, keep open
yield
return
# if file is closed/not referenced: open, do job, then close
instance.hdf5_file = h5py.File(instance.name, 'a')
instance.hdf5_file = h5py.File(instance.name, "a")
yield
instance.hdf5_file.close()
return
Expand Down Expand Up @@ -50,21 +51,21 @@ def activate_file(self):
@property
def samples(self):
g = self.hdf5_file.require_group(str(self.chain))
if 'name' not in g.attrs:
g.attrs['name'] = self.chain
return g.require_group('samples')
if "name" not in g.attrs:
g.attrs["name"] = self.chain
return g.require_group("samples")

@property
def stats(self):
g = self.hdf5_file.require_group(str(self.chain))
if 'name' not in g.attrs:
g.attrs['name'] = self.chain
return g.require_group('stats')
if "name" not in g.attrs:
g.attrs["name"] = self.chain
return g.require_group("stats")

@property
def chains(self):
with self.activate_file:
return [v.attrs['name'] for v in self.hdf5_file.values()]
return [v.attrs["name"] for v in self.hdf5_file.values()]

@property
def is_new_file(self):
Expand All @@ -84,19 +85,19 @@ def nchains(self):
@property
def records_stats(self):
with self.activate_file:
return self.hdf5_file.attrs['records_stats']
return self.hdf5_file.attrs["records_stats"]

@records_stats.setter
def records_stats(self, v):
with self.activate_file:
self.hdf5_file.attrs['records_stats'] = bool(v)
self.hdf5_file.attrs["records_stats"] = bool(v)

def _resize(self, n):
for v in self.samples.values():
v.resize(n, axis=0)
for key, group in self.stats.items():
for statds in group.values():
statds.resize((n, ))
statds.resize((n,))

@property
def sampler_vars(self):
Expand All @@ -123,10 +124,15 @@ def sampler_vars(self, values):
if not data.keys(): # no pre-recorded stats
for varname, dtype in sampler.items():
if varname not in data:
data.create_dataset(varname, (self.draws,), dtype=dtype, maxshape=(None,))
data.create_dataset(
varname, (self.draws,), dtype=dtype, maxshape=(None,)
)
elif data.keys() != sampler.keys():
raise ValueError(
"Sampler vars can't change, names incompatible: {} != {}".format(data.keys(), sampler.keys()))
"Sampler vars can't change, names incompatible: {} != {}".format(
data.keys(), sampler.keys()
)
)
self.records_stats = True

def setup(self, draws, chain, sampler_vars=None):
Expand All @@ -146,16 +152,18 @@ def setup(self, draws, chain, sampler_vars=None):
with self.activate_file:
for varname, shape in self.var_shapes.items():
if varname not in self.samples:
self.samples.create_dataset(name=varname, shape=(draws, ) + shape,
dtype=self.var_dtypes[varname],
maxshape=(None, ) + shape)
self.samples.create_dataset(
name=varname,
shape=(draws,) + shape,
dtype=self.var_dtypes[varname],
maxshape=(None,) + shape,
)
self.draw_idx = len(self)
self.draws = self.draw_idx + draws
self._set_sampler_vars(sampler_vars)
self._is_base_setup = True
self._resize(self.draws)


def close(self):
with self.activate_file:
if self.draw_idx == self.draws:
Expand Down Expand Up @@ -190,8 +198,9 @@ def _slice(self, idx):
start, stop, step = idx.indices(len(self))
sliced = ndarray.NDArray(model=self.model, vars=self.vars)
sliced.chain = self.chain
sliced.samples = {v: self.samples[v][start:stop:step]
for v in self.varnames}
sliced.samples = {
v: self.samples[v][start:stop:step] for v in self.varnames
}
sliced.draw_idx = (stop - start) // step
return sliced

Expand Down
Loading