Skip to content

Commit 02897d1

Browse files
Fix CI (#5683)
Co-authored-by: Michael Osthege <[email protected]>
1 parent 741d455 commit 02897d1

File tree

5 files changed

+29
-25
lines changed

5 files changed

+29
-25
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ repos:
1414
exclude: ^requirements-dev\.txt$
1515
- id: trailing-whitespace
1616
- repo: https://github.com/pre-commit/mirrors-mypy
17-
rev: v0.941
17+
rev: v0.942
1818
hooks:
1919
- id: mypy
2020
name: Run static type checks
@@ -42,11 +42,11 @@ repos:
4242
- id: pyupgrade
4343
args: [--py37-plus]
4444
- repo: https://github.com/psf/black
45-
rev: 22.1.0
45+
rev: 22.3.0
4646
hooks:
4747
- id: black
4848
- repo: https://github.com/PyCQA/pylint
49-
rev: v2.12.2
49+
rev: v2.13.2
5050
hooks:
5151
- id: pylint
5252
args: [--rcfile=.pylintrc]

pymc/backends/arviz.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Iterable,
1010
Mapping,
1111
Optional,
12+
Sequence,
1213
Tuple,
1314
Union,
1415
)
@@ -178,7 +179,10 @@ def __init__(
178179
" one of trace, prior, posterior_predictive or predictions."
179180
)
180181

181-
untyped_coords = {**self.model.coords, **(coords or {})}
182+
# Make coord types more rigid
183+
untyped_coords: Dict[str, Optional[Sequence[Any]]] = {**self.model.coords}
184+
if coords:
185+
untyped_coords.update(coords)
182186
self.coords = {
183187
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
184188
for cname, cvals in untyped_coords.items()
@@ -649,8 +653,8 @@ def predictions_to_inference_data(
649653
)
650654
if hasattr(idata_orig, "posterior"):
651655
assert idata_orig is not None
652-
converter.nchains = idata_orig.posterior.dims["chain"]
653-
converter.ndraws = idata_orig.posterior.dims["draw"]
656+
converter.nchains = idata_orig["posterior"].dims["chain"]
657+
converter.ndraws = idata_orig["posterior"].dims["draw"]
654658
else:
655659
aelem = next(iter(predictions.values()))
656660
converter.nchains, converter.ndraws = aelem.shape[:2]

pymc/backends/report.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
115115
self._add_warnings([warn])
116116
return
117117

118-
if idata.posterior.sizes["chain"] == 1:
118+
if idata["posterior"].sizes["chain"] == 1:
119119
msg = (
120120
"Only one chain was sampled, this makes it impossible to "
121121
"run some convergence checks"
@@ -124,7 +124,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
124124
self._add_warnings([warn])
125125
return
126126

127-
elif idata.posterior.sizes["chain"] < 4:
127+
elif idata["posterior"].sizes["chain"] < 4:
128128
msg = (
129129
"We recommend running at least 4 chains for robust computation of "
130130
"convergence diagnostics"
@@ -140,7 +140,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
140140
if is_transformed_name(rv_name):
141141
rv_name2 = get_untransformed_name(rv_name)
142142
rv_name = rv_name2 if rv_name2 in valid_name else rv_name
143-
if rv_name in idata.posterior:
143+
if rv_name in idata["posterior"]:
144144
varnames.append(rv_name)
145145

146146
self._ess = ess = arviz.ess(idata, var_names=varnames)
@@ -158,7 +158,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
158158
warnings.append(warn)
159159

160160
eff_min = min(val.min() for val in ess.values())
161-
eff_per_chain = eff_min / idata.posterior.sizes["chain"]
161+
eff_per_chain = eff_min / idata["posterior"].sizes["chain"]
162162
if eff_per_chain < 100:
163163
msg = (
164164
"The effective sample size per chain is smaller than 100 for some parameters. "

pymc/sampling.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ def sample(
614614
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
615615
f"took {mtrace.report.t_sampling:.0f} seconds."
616616
)
617+
mtrace.report._log_summary()
617618

618619
idata = None
619620
if compute_convergence_checks or return_inferencedata:
@@ -622,19 +623,18 @@ def sample(
622623
ikwargs.update(idata_kwargs)
623624
idata = pm.to_inference_data(mtrace, **ikwargs)
624625

625-
if compute_convergence_checks:
626-
if draws - tune < 100:
627-
warnings.warn(
628-
"The number of samples is too small to check convergence reliably.", stacklevel=2
629-
)
630-
else:
631-
mtrace.report._run_convergence_checks(idata, model)
632-
mtrace.report._log_summary()
626+
if compute_convergence_checks:
627+
if draws - tune < 100:
628+
warnings.warn(
629+
"The number of samples is too small to check convergence reliably.",
630+
stacklevel=2,
631+
)
632+
else:
633+
mtrace.report._run_convergence_checks(idata, model)
633634

634-
if return_inferencedata:
635-
return idata
636-
else:
637-
return mtrace
635+
if return_inferencedata:
636+
return idata
637+
return mtrace
638638

639639

640640
def _check_start_shape(model, start: PointType):
@@ -1621,7 +1621,7 @@ def sample_posterior_predictive(
16211621
_trace: Union[MultiTrace, PointList]
16221622
nchain: int
16231623
if isinstance(trace, InferenceData):
1624-
_trace = dataset_to_point_list(trace.posterior)
1624+
_trace = dataset_to_point_list(trace["posterior"])
16251625
nchain, len_trace = chains_and_samples(trace)
16261626
elif isinstance(trace, xarray.Dataset):
16271627
_trace = dataset_to_point_list(trace)
@@ -1704,7 +1704,7 @@ def sample_posterior_predictive(
17041704

17051705
if not vars_to_sample:
17061706
if return_inferencedata and not extend_inferencedata:
1707-
return None
1707+
return InferenceData()
17081708
elif return_inferencedata and extend_inferencedata:
17091709
return trace
17101710
return {}

pymc/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tupl
248248
if isinstance(data, xarray.Dataset):
249249
dataset = data
250250
elif isinstance(data, arviz.InferenceData):
251-
dataset = data.posterior
251+
dataset = data["posterior"]
252252
else:
253253
raise ValueError(
254254
"Argument must be xarray Dataset or arviz InferenceData. Got %s",

0 commit comments

Comments
 (0)