Skip to content

Commit dd13a8d

Browse files
committed
Apply pyupgrade
1 parent 9eb69fc commit dd13a8d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+132
-135
lines changed

docs/source/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
# -*- coding: utf-8 -*-
32
#
43
# pymc3 documentation build configuration file, created by
54
# sphinx-quickstart on Sat Dec 26 14:40:23 2015.

docs/source/sphinxext/gallery_generator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ def __init__(self, filename, target_dir):
7070
self.basename = os.path.basename(filename)
7171
self.stripped_name = os.path.splitext(self.basename)[0]
7272
self.output_html = os.path.join(
73-
"..", "notebooks", "{}.html".format(self.stripped_name)
73+
"..", "notebooks", f"{self.stripped_name}.html"
7474
)
7575
self.image_dir = os.path.join(target_dir, "_images")
7676
self.png_path = os.path.join(
77-
self.image_dir, "{}.png".format(self.stripped_name)
77+
self.image_dir, f"{self.stripped_name}.png"
7878
)
79-
with open(filename, "r") as fid:
79+
with open(filename) as fid:
8080
self.json_source = json.load(fid)
8181
self.pagetitle = self.extract_title()
8282
self.default_image_loc = DEFAULT_IMG_LOC
@@ -89,7 +89,7 @@ def __init__(self, filename, target_dir):
8989

9090
self.gen_previews()
9191
else:
92-
print("skipping {0}".format(filename))
92+
print(f"skipping {filename}")
9393

9494
def extract_preview_pic(self):
9595
"""By default, just uses the last image in the notebook."""
@@ -136,7 +136,7 @@ def build_gallery(srcdir, gallery):
136136
working_dir = os.getcwd()
137137
os.chdir(srcdir)
138138
static_dir = os.path.join(srcdir, "_static")
139-
target_dir = os.path.join(srcdir, "nb_{}".format(gallery))
139+
target_dir = os.path.join(srcdir, f"nb_{gallery}")
140140
image_dir = os.path.join(target_dir, "_images")
141141
source_dir = os.path.abspath(
142142
os.path.join(os.path.dirname(os.path.dirname(srcdir)), "notebooks")
@@ -182,8 +182,8 @@ def build_gallery(srcdir, gallery):
182182
"thumb": os.path.basename(default_png_path),
183183
}
184184

185-
js_file = os.path.join(image_dir, "gallery_{}_contents.js".format(gallery))
186-
with open(table_of_contents_file, "r") as toc:
185+
js_file = os.path.join(image_dir, f"gallery_{gallery}_contents.js")
186+
with open(table_of_contents_file) as toc:
187187
table_of_contents = toc.read()
188188

189189
js_contents = "Gallery.examples = {}\n{}".format(

pymc3/backends/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,8 @@ def __getitem__(self, idx):
347347
return self.get_sampler_stats(var, burn=burn, thin=thin)
348348
raise KeyError("Unknown variable %s" % var)
349349

350-
_attrs = set(['_straces', 'varnames', 'chains', 'stat_names',
351-
'supports_sampler_stats', '_report'])
350+
_attrs = {'_straces', 'varnames', 'chains', 'stat_names',
351+
'supports_sampler_stats', '_report'}
352352

353353
def __getattr__(self, name):
354354
# Avoid infinite recursion when called before __init__
@@ -417,7 +417,7 @@ def add_values(self, vals, overwrite=False) -> None:
417417
self.varnames.remove(k)
418418
new_var = 0
419419
else:
420-
raise ValueError("Variable name {} already exists.".format(k))
420+
raise ValueError(f"Variable name {k} already exists.")
421421

422422
self.varnames.append(k)
423423

@@ -448,7 +448,7 @@ def remove_values(self, name):
448448
"""
449449
varnames = self.varnames
450450
if name not in varnames:
451-
raise KeyError("Unknown variable {}".format(name))
451+
raise KeyError(f"Unknown variable {name}")
452452
self.varnames.remove(name)
453453
chains = self._straces
454454
for chain in chains.values():

pymc3/backends/hdf5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def sampler_vars(self, values):
140140
data.create_dataset(varname, (self.draws,), dtype=dtype, maxshape=(None,))
141141
elif data.keys() != sampler.keys():
142142
raise ValueError(
143-
"Sampler vars can't change, names incompatible: {} != {}".format(data.keys(), sampler.keys()))
143+
f"Sampler vars can't change, names incompatible: {data.keys()} != {sampler.keys()}")
144144
self.records_stats = True
145145

146146
def setup(self, draws, chain, sampler_vars=None):

pymc3/backends/ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def load(self, model: Model) -> 'NDArray':
176176
raise TraceDirectoryError("%s is not a trace directory" % self.directory)
177177

178178
new_trace = NDArray(model=model)
179-
with open(self.metadata_path, 'r') as buff:
179+
with open(self.metadata_path) as buff:
180180
metadata = json.load(buff)
181181

182182
metadata['_stats'] = [{k: np.array(v) for k, v in stat.items()} for stat in metadata['_stats']]

pymc3/backends/sqlite.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,8 @@ def load(name, model=None):
340340
db.connect()
341341
varnames = _get_table_list(db.cursor)
342342
if len(varnames) == 0:
343-
raise ValueError(('Can not get variable list for database'
344-
'`{}`'.format(name)))
343+
raise ValueError('Can not get variable list for database'
344+
'`{}`'.format(name))
345345
chains = _get_chain_list(db.cursor, varnames[0])
346346

347347
straces = []
@@ -367,14 +367,14 @@ def _get_table_list(cursor):
367367

368368

369369
def _get_var_strs(cursor, varname):
370-
cursor.execute('SELECT * FROM [{}]'.format(varname))
370+
cursor.execute(f'SELECT * FROM [{varname}]')
371371
col_names = (col_descr[0] for col_descr in cursor.description)
372372
return [name for name in col_names if name.startswith('v')]
373373

374374

375375
def _get_chain_list(cursor, varname):
376376
"""Return a list of sorted chains for `varname`."""
377-
cursor.execute('SELECT DISTINCT chain FROM [{}]'.format(varname))
377+
cursor.execute(f'SELECT DISTINCT chain FROM [{varname}]')
378378
chains = sorted([chain[0] for chain in cursor.fetchall()])
379379
return chains
380380

pymc3/backends/text.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def setup(self, draws, chain):
9191
self._fh.close()
9292

9393
self.chain = chain
94-
self.filename = os.path.join(self.name, 'chain-{}.csv'.format(chain))
94+
self.filename = os.path.join(self.name, f'chain-{chain}.csv')
9595

9696
cnames = [fv for v in self.varnames for fv in self.flat_names[v]]
9797

@@ -201,7 +201,7 @@ def load(name, model=None):
201201
files = glob(os.path.join(name, 'chain-*.csv'))
202202

203203
if len(files) == 0:
204-
raise ValueError('No files present in directory {}'.format(name))
204+
raise ValueError(f'No files present in directory {name}')
205205

206206
straces = []
207207
for f in files:
@@ -249,7 +249,7 @@ def dump(name, trace, chains=None):
249249
chains = trace.chains
250250

251251
for chain in chains:
252-
filename = os.path.join(name, 'chain-{}.csv'.format(chain))
252+
filename = os.path.join(name, f'chain-{chain}.csv')
253253
df = ttab.trace_to_dataframe(
254254
trace, chains=chain, include_transformed=True)
255255
df.to_csv(filename, index=False)

pymc3/distributions/continuous.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def assert_negative_support(var, label, distname, value=-1e-6):
134134
support = False
135135

136136
if np.any(support):
137-
msg = "The variable specified for {0} has negative support for {1}, ".format(
137+
msg = "The variable specified for {} has negative support for {}, ".format(
138138
label, distname
139139
)
140140
msg += "likely making it unsuitable for this parameter."
@@ -294,7 +294,7 @@ def logcdf(self, value):
294294
tt.switch(
295295
tt.eq(value, self.upper),
296296
0,
297-
tt.log((value - self.lower)) - tt.log((self.upper - self.lower)),
297+
tt.log(value - self.lower) - tt.log(self.upper - self.lower),
298298
),
299299
)
300300

@@ -1887,7 +1887,7 @@ class StudentT(Continuous):
18871887

18881888
def __init__(self, nu, mu=0, lam=None, sigma=None, sd=None, *args, **kwargs):
18891889
super().__init__(*args, **kwargs)
1890-
super(StudentT, self).__init__(*args, **kwargs)
1890+
super().__init__(*args, **kwargs)
18911891
if sd is not None:
18921892
sigma = sd
18931893
warnings.warn("sd is deprecated, use sigma instead", DeprecationWarning)

pymc3/distributions/distribution.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __new__(cls, name, *args, **kwargs):
6060
"for a standalone distribution.")
6161

6262
if not isinstance(name, string_types):
63-
raise TypeError("Name needs to be a string but got: {}".format(name))
63+
raise TypeError(f"Name needs to be a string but got: {name}")
6464

6565
data = kwargs.pop('observed', None)
6666
cls.data = data
@@ -728,7 +728,7 @@ def draw_values(params, point=None, size=None):
728728
# test_distributions_random::TestDrawValues::test_draw_order fails without it
729729
# The remaining params that must be drawn are all hashable
730730
to_eval = set()
731-
missing_inputs = set([j for j, p in symbolic_params])
731+
missing_inputs = {j for j, p in symbolic_params}
732732
while to_eval or missing_inputs:
733733
if to_eval == missing_inputs:
734734
raise ValueError('Cannot resolve inputs for {}'.format([get_var_name(params[j]) for j in to_eval]))
@@ -828,7 +828,7 @@ def vectorize_theano_function(f, inputs, output):
828828
"""
829829
inputs_signatures = ",".join(
830830
[
831-
get_vectorize_signature(var, var_name="i_{}".format(input_ind))
831+
get_vectorize_signature(var, var_name=f"i_{input_ind}")
832832
for input_ind, var in enumerate(inputs)
833833
]
834834
)
@@ -846,9 +846,9 @@ def get_vectorize_signature(var, var_name="i"):
846846
return "()"
847847
else:
848848
sig = ",".join(
849-
["{}_{}".format(var_name, axis_ind) for axis_ind in range(var.ndim)]
849+
[f"{var_name}_{axis_ind}" for axis_ind in range(var.ndim)]
850850
)
851-
return "({})".format(sig)
851+
return f"({sig})"
852852

853853

854854
def _draw_value(param, point=None, givens=None, size=None):

pymc3/distributions/mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __init__(self, w, comp_dists, *args, **kwargs):
118118
isinstance(comp_dists, Distribution)
119119
or (
120120
isinstance(comp_dists, Iterable)
121-
and all((isinstance(c, Distribution) for c in comp_dists))
121+
and all(isinstance(c, Distribution) for c in comp_dists)
122122
)
123123
):
124124
raise TypeError(

pymc3/distributions/posterior_predictive.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(
8888
assert point_list is None and dict is None
8989
self.data = {} # Dict[str, np.ndarray]
9090
self._len = sum(
91-
(len(multi_trace._straces[chain]) for chain in multi_trace.chains)
91+
len(multi_trace._straces[chain]) for chain in multi_trace.chains
9292
)
9393
self.varnames = multi_trace.varnames
9494
for vn in multi_trace.varnames:
@@ -153,15 +153,15 @@ def __getitem__(self, item: Union[slice, int]) -> "_TraceDict":
153153

154154
def __getitem__(self, item):
155155
if isinstance(item, str):
156-
return super(_TraceDict, self).__getitem__(item)
156+
return super().__getitem__(item)
157157
elif isinstance(item, slice):
158158
return self._extract_slice(item)
159159
elif isinstance(item, int):
160160
return _TraceDict(
161161
dict={k: np.atleast_1d(v[item]) for k, v in self.data.items()}
162162
)
163163
elif hasattr(item, "name"):
164-
return super(_TraceDict, self).__getitem__(item.name)
164+
return super().__getitem__(item.name)
165165
else:
166166
raise IndexError("Illegal index %s for _TraceDict" % str(item))
167167

@@ -242,7 +242,7 @@ def fast_sample_posterior_predictive(
242242
"Should not specify both keep_size and samples arguments"
243243
)
244244

245-
if isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
245+
if isinstance(trace, list) and all(isinstance(x, dict) for x in trace):
246246
_trace = _TraceDict(point_list=trace)
247247
elif isinstance(trace, MultiTrace):
248248
_trace = _TraceDict(multi_trace=trace)
@@ -454,7 +454,7 @@ def draw_values(self) -> List[np.ndarray]:
454454
# test_distributions_random::TestDrawValues::test_draw_order fails without it
455455
# The remaining params that must be drawn are all hashable
456456
to_eval: Set[int] = set()
457-
missing_inputs: Set[int] = set([j for j, p in self.symbolic_params])
457+
missing_inputs: Set[int] = {j for j, p in self.symbolic_params}
458458

459459
while to_eval or missing_inputs:
460460
if to_eval == missing_inputs:

pymc3/distributions/shape_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ def _check_shape_type(shape):
5858
shape = np.atleast_1d(shape)
5959
for s in shape:
6060
if isinstance(s, np.ndarray) and s.ndim > 0:
61-
raise TypeError("Value {} is not a valid integer".format(s))
61+
raise TypeError(f"Value {s} is not a valid integer")
6262
o = int(s)
6363
if o != s:
64-
raise TypeError("Value {} is not a valid integer".format(s))
64+
raise TypeError(f"Value {s} is not a valid integer")
6565
out.append(o)
6666
except Exception:
6767
raise TypeError(
68-
"Supplied value {} does not represent a valid shape".format(shape)
68+
f"Supplied value {shape} does not represent a valid shape"
6969
)
7070
return tuple(out)
7171

@@ -103,7 +103,7 @@ def shapes_broadcasting(*args, raise_exception=False):
103103
if raise_exception:
104104
raise ValueError(
105105
"Supplied shapes {} do not broadcast together".format(
106-
", ".join(["{}".format(a) for a in args])
106+
", ".join([f"{a}" for a in args])
107107
)
108108
)
109109
else:
@@ -165,7 +165,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
165165
if broadcasted_shape is None:
166166
raise ValueError(
167167
"Cannot broadcast provided shapes {} given size: {}".format(
168-
", ".join(["{}".format(s) for s in shapes]), size
168+
", ".join([f"{s}" for s in shapes]), size
169169
)
170170
)
171171
return broadcasted_shape
@@ -181,7 +181,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
181181
except ValueError:
182182
raise ValueError(
183183
"Cannot broadcast provided shapes {} given size: {}".format(
184-
", ".join(["{}".format(s) for s in shapes]), size
184+
", ".join([f"{s}" for s in shapes]), size
185185
)
186186
)
187187
broadcastable_shapes = []

pymc3/distributions/simulator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def __init__(
4343
Distance functions. Available options are "gaussian_kernel" (default), "wasserstein",
4444
"energy" or a user defined function that takes epsilon (a scalar), and the summary
4545
statistics of observed_data, and simulated_data as input.
46-
``gaussian_kernel`` :math: `\sum \left(-0.5 \left(\frac{xo - xs}{\epsilon}\right)^2\right)`
47-
``wasserstein`` :math: `\frac{1}{n} \sum{\left(\frac{|xo - xs|}{\epsilon}\right)}`
48-
``energy`` :math: `\sqrt{2} \sqrt{\frac{1}{n} \sum \left(\frac{|xo - xs|}{\epsilon}\right)^2}`
46+
``gaussian_kernel`` :math: `\\sum \\left(-0.5 \\left(\frac{xo - xs}{\\epsilon}\right)^2\right)`
47+
``wasserstein`` :math: `\frac{1}{n} \\sum{\\left(\frac{|xo - xs|}{\\epsilon}\right)}`
48+
``energy`` :math: `\\sqrt{2} \\sqrt{\frac{1}{n} \\sum \\left(\frac{|xo - xs|}{\\epsilon}\right)^2}`
4949
For the wasserstein and energy distances the observed data xo and simulated data xs
5050
are internally sorted (i.e. the sum_stat is "sort").
5151
sum_stat: str or callable
@@ -125,7 +125,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
125125
distance = self.distance.__name__
126126

127127
if formatting == "latex":
128-
return f"$\\text{{{name}}} \sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
128+
return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
129129
else:
130130
return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})"
131131

pymc3/examples/samplers_mvnormal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def run(steppers, p):
5353
print('{} samples across {} chains'.format(len(mt) * mt.nchains, mt.nchains))
5454
traces[name] = mt
5555
en = pm.ess(mt)
56-
print('effective: {}\r\n'.format(en))
56+
print(f'effective: {en}\r\n')
5757
if USE_XY:
5858
effn[name] = np.mean(en['x']) / len(mt) / mt.nchains
5959
else:

pymc3/exceptions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ class ShapeError(Exception):
4545
"""Error that the shape of a variable is incorrect."""
4646
def __init__(self, message, actual=None, expected=None):
4747
if actual is not None and expected is not None:
48-
super().__init__('{} (actual {} != expected {})'.format(message, actual, expected))
48+
super().__init__(f'{message} (actual {actual} != expected {expected})')
4949
elif actual is not None and expected is None:
50-
super().__init__('{} (actual {})'.format(message, actual))
50+
super().__init__(f'{message} (actual {actual})')
5151
elif actual is None and expected is not None:
52-
super().__init__('{} (expected {})'.format(message, expected))
52+
super().__init__(f'{message} (expected {expected})')
5353
else:
5454
super().__init__(message)
5555

@@ -58,10 +58,10 @@ class DtypeError(TypeError):
5858
"""Error that the dtype of a variable is incorrect."""
5959
def __init__(self, message, actual=None, expected=None):
6060
if actual is not None and expected is not None:
61-
super().__init__('{} (actual {} != expected {})'.format(message, actual, expected))
61+
super().__init__(f'{message} (actual {actual} != expected {expected})')
6262
elif actual is not None and expected is None:
63-
super().__init__('{} (actual {})'.format(message, actual))
63+
super().__init__(f'{message} (actual {actual})')
6464
elif actual is None and expected is not None:
65-
super().__init__('{} (expected {})'.format(message, expected))
65+
super().__init__(f'{message} (expected {expected})')
6666
else:
6767
super().__init__(message)

0 commit comments

Comments
 (0)