Skip to content

Commit 2e46eb2

Browse files
apply black
1 parent 7007aab commit 2e46eb2

File tree

3 files changed

+10
-23
lines changed

3 files changed

+10
-23
lines changed

pymc3/tests/test_memo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
def test_memo():
2121
def fun(inputs, suffix="_a"):
2222
return str(inputs) + str(suffix)
23+
2324
inputs = ["i1", "i2"]
2425
assert fun(inputs) == "['i1', 'i2']_a"
2526
assert fun(inputs, "_b") == "['i1', 'i2']_b"
@@ -28,7 +29,9 @@ def fun(inputs, suffix="_a"):
2829
assert hasattr(fun, "cache")
2930
assert isinstance(fun.cache, dict)
3031
assert len(fun.cache) == 0
31-
32+
33+
# call the memoized function with a list input
34+
# and check the size of the cache!
3235
assert funmem(inputs) == "['i1', 'i2']_a"
3336
assert funmem(inputs) == "['i1', 'i2']_a"
3437
assert len(fun.cache) == 1
@@ -49,10 +52,7 @@ def test_hashing_of_rv_tuples():
4952
mu = pm.Normal("mu", 0, 1)
5053
sd = pm.Gamma("sd", 1, 2)
5154
dd = pm.DensityDist(
52-
"dd",
53-
pm.Normal.dist(mu, sd).logp,
54-
random=pm.Normal.dist(mu, sd).random,
55-
observed=obs,
55+
"dd", pm.Normal.dist(mu, sd).logp, random=pm.Normal.dist(mu, sd).random, observed=obs,
5656
)
5757
print()
5858
for freerv in [mu, sd, dd] + pmodel.free_RVs:

pymc3/tests/test_sampling.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,14 @@ def test_sample(self):
7373
for cores in test_cores:
7474
for steps in [1, 10, 300]:
7575
pm.sample(
76-
steps,
77-
tune=0,
78-
step=self.step,
79-
cores=cores,
80-
random_seed=self.random_seed,
76+
steps, tune=0, step=self.step, cores=cores, random_seed=self.random_seed,
8177
)
8278

8379
def test_sample_init(self):
8480
with self.model:
8581
for init in ("advi", "advi_map", "map"):
8682
pm.sample(
87-
init=init,
88-
tune=0,
89-
n_init=1000,
90-
draws=50,
91-
random_seed=self.random_seed,
83+
init=init, tune=0, n_init=1000, draws=50, random_seed=self.random_seed,
9284
)
9385

9486
def test_sample_args(self):
@@ -108,11 +100,7 @@ def test_sample_args(self):
108100
def test_iter_sample(self):
109101
with self.model:
110102
samps = pm.sampling.iter_sample(
111-
draws=5,
112-
step=self.step,
113-
start=self.start,
114-
tune=0,
115-
random_seed=self.random_seed,
103+
draws=5, step=self.step, start=self.start, tune=0, random_seed=self.random_seed,
116104
)
117105
for i, trace in enumerate(samps):
118106
assert i == len(trace) - 1, "Trace does not have correct length."

pymc3/util.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def enhanced(*args, **kwargs):
261261
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
262262
warnings.warn(
263263
"dataset_to_point_dict was renamed to dataset_to_point_list and will be removed!.",
264-
DeprecationWarning
264+
DeprecationWarning,
265265
)
266266
return dataset_to_point_list(ds)
267267

@@ -289,8 +289,7 @@ def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tupl
289289
dataset = data.posterior
290290
else:
291291
raise ValueError(
292-
"Argument must be xarray Dataset or arviz InferenceData. Got %s",
293-
data.__class__,
292+
"Argument must be xarray Dataset or arviz InferenceData. Got %s", data.__class__,
294293
)
295294

296295
coords = dataset.coords

0 commit comments

Comments
 (0)