Skip to content

Commit 8e2f1ed

Browse files
OriolAbril and brandonwillard
authored and committed
add workaround for data groups until next arviz release
1 parent 8143878 commit 8e2f1ed

File tree

2 files changed

+55
-31
lines changed

2 files changed

+55
-31
lines changed

pymc3/backends/arviz.py

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
from aesara.graph.basic import Constant
2121
from aesara.tensor.sharedvar import SharedVariable
2222
from arviz import InferenceData, concat, rcParams
23-
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
23+
from arviz.data.base import CoordSpec, DimSpec
24+
from arviz.data.base import dict_to_dataset as _dict_to_dataset
25+
from arviz.data.base import generate_dims_coords, make_attrs, requires
2426

2527
import pymc3
2628

@@ -98,6 +100,37 @@ def insert(self, k: str, v, idx: int):
98100
self.trace_dict[k][idx, :] = v
99101

100102

103+
def dict_to_dataset(
104+
data,
105+
library=None,
106+
coords=None,
107+
dims=None,
108+
attrs=None,
109+
default_dims=None,
110+
skip_event_dims=None,
111+
index_origin=None,
112+
):
113+
"""Temporary workaround for dict_to_dataset.
114+
115+
Once ArviZ>0.11.2 release is available, only two changes are needed for everything to work.
116+
1) this should be deleted, 2) dict_to_dataset should be imported as is from arviz, no underscore,
117+
also remove unnecessary imports
118+
"""
119+
if default_dims is None:
120+
return _dict_to_dataset(
121+
data, library=library, coords=coords, dims=dims, skip_event_dims=skip_event_dims
122+
)
123+
else:
124+
out_data = {}
125+
for name, vals in data.items():
126+
vals = np.atleast_1d(vals)
127+
val_dims = dims.get(name)
128+
val_dims, coords = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
129+
coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
130+
out_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
131+
return xr.Dataset(data_vars=out_data, attrs=make_attrs(library=library))
132+
133+
101134
class InferenceDataConverter: # pylint: disable=too-many-instance-attributes
102135
"""Encapsulate InferenceData specific logic."""
103136

@@ -196,14 +229,13 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
196229
self.dims = {**model_dims, **self.dims}
197230

198231
self.density_dist_obs = density_dist_obs
199-
self.observations, self.multi_observations = self.find_observations()
232+
self.observations = self.find_observations()
200233

201-
def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str, Var]]]:
234+
def find_observations(self) -> Optional[Dict[str, Var]]:
202235
"""If there are observations available, return them as a dictionary."""
203236
if self.model is None:
204-
return (None, None)
237+
return None
205238
observations = {}
206-
multi_observations = {}
207239
for obs in self.model.observed_RVs:
208240
aux_obs = getattr(obs.tag, "observations", None)
209241
if aux_obs is not None:
@@ -215,7 +247,7 @@ def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str
215247
else:
216248
warnings.warn(f"No data for observation {obs}")
217249

218-
return observations, multi_observations
250+
return observations
219251

220252
def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
221253
"""Split MultiTrace object into posterior and warmup.
@@ -302,15 +334,15 @@ def posterior_to_xarray(self):
302334
coords=self.coords,
303335
dims=self.dims,
304336
attrs=self.attrs,
305-
# index_origin=self.index_origin,
337+
index_origin=self.index_origin,
306338
),
307339
dict_to_dataset(
308340
data_warmup,
309341
library=pymc3,
310342
coords=self.coords,
311343
dims=self.dims,
312344
attrs=self.attrs,
313-
# index_origin=self.index_origin,
345+
index_origin=self.index_origin,
314346
),
315347
)
316348

@@ -344,15 +376,15 @@ def sample_stats_to_xarray(self):
344376
dims=None,
345377
coords=self.coords,
346378
attrs=self.attrs,
347-
# index_origin=self.index_origin,
379+
index_origin=self.index_origin,
348380
),
349381
dict_to_dataset(
350382
data_warmup,
351383
library=pymc3,
352384
dims=None,
353385
coords=self.coords,
354386
attrs=self.attrs,
355-
# index_origin=self.index_origin,
387+
index_origin=self.index_origin,
356388
),
357389
)
358390

@@ -385,15 +417,15 @@ def log_likelihood_to_xarray(self):
385417
dims=self.dims,
386418
coords=self.coords,
387419
skip_event_dims=True,
388-
# index_origin=self.index_origin,
420+
index_origin=self.index_origin,
389421
),
390422
dict_to_dataset(
391423
data_warmup,
392424
library=pymc3,
393425
dims=self.dims,
394426
coords=self.coords,
395427
skip_event_dims=True,
396-
# index_origin=self.index_origin,
428+
index_origin=self.index_origin,
397429
),
398430
)
399431

@@ -415,11 +447,7 @@ def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
415447
k,
416448
)
417449
return dict_to_dataset(
418-
data,
419-
library=pymc3,
420-
coords=self.coords,
421-
# dims=self.dims,
422-
# index_origin=self.index_origin
450+
data, library=pymc3, coords=self.coords, dims=self.dims, index_origin=self.index_origin
423451
)
424452

425453
@requires(["posterior_predictive"])
@@ -454,25 +482,25 @@ def priors_to_xarray(self):
454482
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
455483
library=pymc3,
456484
coords=self.coords,
457-
# dims=self.dims,
458-
# index_origin=self.index_origin,
485+
dims=self.dims,
486+
index_origin=self.index_origin,
459487
)
460488
)
461489
return priors_dict
462490

463-
@requires(["observations", "multi_observations"])
491+
@requires("observations")
464492
@requires("model")
465493
def observed_data_to_xarray(self):
466494
"""Convert observed data to xarray."""
467495
if self.predictions:
468496
return None
469497
return dict_to_dataset(
470-
{**self.observations, **self.multi_observations},
498+
self.observations,
471499
library=pymc3,
472500
coords=self.coords,
473-
# dims=self.dims,
474-
# default_dims=[],
475-
# index_origin=self.index_origin,
501+
dims=self.dims,
502+
default_dims=[],
503+
index_origin=self.index_origin,
476504
)
477505

478506
@requires(["trace", "predictions"])
@@ -517,9 +545,9 @@ def is_data(name, var) -> bool:
517545
constant_data,
518546
library=pymc3,
519547
coords=self.coords,
520-
# dims=self.dims,
521-
# default_dims=[],
522-
# index_origin=self.index_origin,
548+
dims=self.dims,
549+
default_dims=[],
550+
index_origin=self.index_origin,
523551
)
524552

525553
def to_inference_data(self):

pymc3/tests/test_idata_conversion.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,6 @@ def test_multivariate_observations(self):
570570

571571

572572
class TestPyMC3WarmupHandling:
573-
@pytest.mark.skipif(
574-
not hasattr(pm.backends.base.SamplerReport, "n_draws"),
575-
reason="requires pymc3 3.9 or higher",
576-
)
577573
@pytest.mark.parametrize("save_warmup", [False, True])
578574
@pytest.mark.parametrize("chains", [1, 2])
579575
@pytest.mark.parametrize("tune,draws", [(0, 50), (10, 40), (30, 0)])

0 commit comments

Comments (0)