Skip to content

Commit 3036da5

Browse files
Support xarray input to sample_posterior_predictive (#3846)
* add test for xarray input to sample_posterior_predictive * support for xarray.Dataset as trace argument to sample_posterior_predictive and fast_sample_posterior_predictive. * closes #3828 Co-authored-by: Michael Osthege <[email protected]> Co-authored-by: rpgoldman <[email protected]>
1 parent 29821a5 commit 3036da5

File tree

5 files changed

+69
-6
lines changed

5 files changed

+69
-6
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- `DEMetropolisZ`, an improved variant of `DEMetropolis`, brings better parallelization and higher efficiency with fewer chains, at the cost of slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
99
- Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
1010
- Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models.
11+
- `sample_posterior_predictive` and `fast_sample_posterior_predictive` now accept an `xarray.Dataset` as the `trace` argument - e.g. `InferenceData.posterior` or `InferenceData.prior`. (see [#3846](https://github.com/pymc-devs/pymc3/pull/3846))
1112
- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))
1213

1314
### Maintenance

pymc3/distributions/posterior_predictive.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
import numpy as np
1313
import theano
1414
import theano.tensor as tt
15+
from xarray import Dataset
1516

1617
from ..backends.base import MultiTrace #, TraceLike, TraceDict
1718
from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function, vectorized_ppc
1819
from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext
1920
from ..exceptions import IncorrectArgumentsError
2021
from ..vartypes import theano_constant
22+
from ..util import dataset_to_point_dict
2123
# Failing tests:
2224
# test_mixture_random_shape::test_mixture_random_shape
2325
#
@@ -119,7 +121,7 @@ def __getitem__(self, item):
119121

120122

121123

122-
def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.ndarray]]],
124+
def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]],
123125
samples: Optional[int]=None,
124126
model: Optional[Model]=None,
125127
var_names: Optional[List[str]]=None,
@@ -135,7 +137,7 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
135137
136138
Parameters
137139
----------
138-
trace : MultiTrace or List of points
140+
trace : MultiTrace, xarray.Dataset, or List of points (dictionary)
139141
Trace generated from MCMC sampling.
140142
samples : int, optional
141143
Number of posterior predictive samples to generate. Defaults to one posterior predictive
@@ -168,6 +170,9 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
168170
### greater than the number of samples in the trace parameter, we sample repeatedly. This
169171
### makes the shape issues just a little easier to deal with.
170172

173+
if isinstance(trace, Dataset):
174+
trace = dataset_to_point_dict(trace)
175+
171176
model = modelcontext(model)
172177
assert model is not None
173178
with model:

pymc3/sampling.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import numpy as np
3131
import theano.gradient as tg
3232
from theano.tensor import Tensor
33+
import xarray
3334

3435
from .backends.base import BaseTrace, MultiTrace
3536
from .backends.ndarray import NDArray
@@ -53,6 +54,7 @@
5354
get_untransformed_name,
5455
is_transformed_name,
5556
get_default_varnames,
57+
dataset_to_point_dict,
5658
)
5759
from .vartypes import discrete_types
5860
from .exceptions import IncorrectArgumentsError
@@ -1520,9 +1522,9 @@ def sample_posterior_predictive(
15201522
15211523
Parameters
15221524
----------
1523-
trace: backend, list, or MultiTrace
1524-
Trace generated from MCMC sampling. Or a list containing dicts from
1525-
find_MAP() or points
1525+
trace: backend, list, xarray.Dataset, or MultiTrace
1526+
Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()),
1527+
or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior)
15261528
samples: int
15271529
Number of posterior predictive samples to generate. Defaults to one posterior predictive
15281530
sample per posterior sample, that is, the number of draws times the number of chains. It
@@ -1556,6 +1558,9 @@ def sample_posterior_predictive(
15561558
Dictionary with the variable names as keys, and values numpy arrays containing
15571559
posterior predictive samples.
15581560
"""
1561+
if isinstance(trace, xarray.Dataset):
1562+
trace = dataset_to_point_dict(trace)
1563+
15591564
len_trace = len(trace)
15601565
try:
15611566
nchain = trace.nchains

pymc3/tests/test_sampling.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import mock
2323

2424
import numpy.testing as npt
25+
import arviz as az
2526
import pymc3 as pm
2627
import theano.tensor as tt
2728
from theano import shared
@@ -880,3 +881,32 @@ def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
880881
var_names=['d']
881882
)
882883

884+
def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):
    """Check that ``sample_posterior_predictive`` accepts ``InferenceData.prior``.

    The prior group of the ArviZ ``InferenceData`` object is passed as the
    ``trace`` argument, and the requested variable ``'d'`` must be present
    in the returned dictionary.
    """
    pmodel, trace = point_list_arg_bug_fixture

    with pmodel:
        prior = pm.sample_prior_predictive(samples=20)
    idat = az.from_pymc3(trace, prior=prior)
    with pmodel:
        pp = pm.sample_posterior_predictive(
            idat.prior,
            var_names=['d']
        )
    # Without this check the test would only detect a crash, not a wrong result.
    assert 'd' in pp
895+
896+
def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
    """Check that ``sample_posterior_predictive`` accepts ``InferenceData.posterior``.

    The posterior group of the ArviZ ``InferenceData`` object is passed as the
    ``trace`` argument, and the requested variable ``'d'`` must be present
    in the returned dictionary.
    """
    pmodel, trace = point_list_arg_bug_fixture
    idat = az.from_pymc3(trace)
    with pmodel:
        pp = pm.sample_posterior_predictive(
            idat.posterior,
            var_names=['d']
        )
    # Without this check the test would only detect a crash, not a wrong result.
    assert 'd' in pp
904+
905+
def test_sample_from_xarray_posterior_fast(self, point_list_arg_bug_fixture):
    """Check that ``fast_sample_posterior_predictive`` accepts ``InferenceData.posterior``.

    The posterior group of the ArviZ ``InferenceData`` object is passed as the
    ``trace`` argument, and the requested variable ``'d'`` is expected in the
    result.
    """
    pmodel, trace = point_list_arg_bug_fixture
    idat = az.from_pymc3(trace)
    with pmodel:
        pp = pm.fast_sample_posterior_predictive(
            idat.posterior,
            var_names=['d']
        )
    # NOTE(review): assumes the fast variant returns a mapping keyed by
    # variable name, like sample_posterior_predictive - confirm.
    assert 'd' in pp

pymc3/util.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414

1515
import re
1616
import functools
17-
from numpy import asscalar
17+
from typing import List, Dict
18+
19+
import xarray
20+
from numpy import asscalar, ndarray
21+
1822

1923
LATEX_ESCAPE_RE = re.compile(r'(%|_|\$|#|&)', re.MULTILINE)
2024

@@ -179,3 +183,21 @@ def enhanced(*args, **kwargs):
179183
newwrapper = functools.partial(wrapper, *args, **kwargs)
180184
return newwrapper
181185
return enhanced
186+
187+
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
    """Flatten an ``xarray.Dataset`` of samples into a list of point dicts.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``chain`` and ``draw`` dimensions, e.g.
        ``InferenceData.posterior`` or ``InferenceData.prior``.

    Returns
    -------
    list of dict
        One dictionary per (chain, draw) pair, ordered chain-major. Each
        dictionary maps a variable name to the value of that variable for
        that chain and draw.
    """
    # Extract the raw arrays once instead of once per point.
    # NOTE(review): assumes every variable is laid out as (chain, draw, ...),
    # as the indexing below requires - confirm for non-ArviZ datasets.
    samples = {vn: ds[vn].values for vn in ds.keys()}

    # Index by position, not by coordinate label: labels need not be
    # 0..n-1 (for example after thinning or slicing the dataset), and using
    # them as array indices would then pick the wrong sample or go out of
    # bounds.
    n_chains = ds.sizes["chain"]
    n_draws = ds.sizes["draw"]

    return [
        {vn: values[c, d] for vn, values in samples.items()}
        for c in range(n_chains)
        for d in range(n_draws)
    ]

0 commit comments

Comments
 (0)