Skip to content

Commit e94b47e

Browse files
apply black via pre-commit
1 parent 0f21fca commit e94b47e

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

pymc3/tests/test_memo.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ def test_hashing_of_rv_tuples():
5252
mu = pm.Normal("mu", 0, 1)
5353
sd = pm.Gamma("sd", 1, 2)
5454
dd = pm.DensityDist(
55-
"dd", pm.Normal.dist(mu, sd).logp, random=pm.Normal.dist(mu, sd).random, observed=obs,
55+
"dd",
56+
pm.Normal.dist(mu, sd).logp,
57+
random=pm.Normal.dist(mu, sd).random,
58+
observed=obs,
5659
)
5760
print()
5861
for freerv in [mu, sd, dd] + pmodel.free_RVs:

pymc3/tests/test_sampling.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,22 @@ def test_sample(self):
7272
for cores in test_cores:
7373
for steps in [1, 10, 300]:
7474
pm.sample(
75-
steps, tune=0, step=self.step, cores=cores, random_seed=self.random_seed,
75+
steps,
76+
tune=0,
77+
step=self.step,
78+
cores=cores,
79+
random_seed=self.random_seed,
7680
)
7781

7882
def test_sample_init(self):
7983
with self.model:
8084
for init in ("advi", "advi_map", "map"):
8185
pm.sample(
82-
init=init, tune=0, n_init=1000, draws=50, random_seed=self.random_seed,
86+
init=init,
87+
tune=0,
88+
n_init=1000,
89+
draws=50,
90+
random_seed=self.random_seed,
8391
)
8492

8593
def test_sample_args(self):
@@ -99,7 +107,11 @@ def test_sample_args(self):
99107
def test_iter_sample(self):
100108
with self.model:
101109
samps = pm.sampling.iter_sample(
102-
draws=5, step=self.step, start=self.start, tune=0, random_seed=self.random_seed,
110+
draws=5,
111+
step=self.step,
112+
start=self.start,
113+
tune=0,
114+
random_seed=self.random_seed,
103115
)
104116
for i, trace in enumerate(samps):
105117
assert i == len(trace) - 1, "Trace does not have correct length."

pymc3/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@ def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tupl
289289
dataset = data.posterior
290290
else:
291291
raise ValueError(
292-
"Argument must be xarray Dataset or arviz InferenceData. Got %s", data.__class__,
292+
"Argument must be xarray Dataset or arviz InferenceData. Got %s",
293+
data.__class__,
293294
)
294295

295296
coords = dataset.coords

0 commit comments

Comments
 (0)