Commit 796c9bb

support xarray.Dataset in sample_posterior_predictive
+ closes pymc-devs#3828
1 parent 6371fcc commit 796c9bb

2 files changed: +22 -3 lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 - `DEMetropolisZ`, an improved variant of `DEMetropolis` brings better parallelization and higher efficiency with fewer chains with a slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
 - Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
 - Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models.
+- `sample_posterior_predictive` can now feed on `xarray.Dataset` - e.g. from `InferenceData.posterior`. (see [#3846](https://github.com/pymc-devs/pymc3/pull/3846))
 - `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))
 
 ### Maintenance

pymc3/sampling.py

Lines changed: 21 additions & 3 deletions
@@ -30,6 +30,7 @@
 import numpy as np
 import theano.gradient as tg
 from theano.tensor import Tensor
+import xarray
 
 from .backends.base import BaseTrace, MultiTrace
 from .backends.ndarray import NDArray
@@ -1520,9 +1521,9 @@ def sample_posterior_predictive(
 
     Parameters
     ----------
-    trace: backend, list, or MultiTrace
-        Trace generated from MCMC sampling. Or a list containing dicts from
-        find_MAP() or points
+    trace: backend, list, xarray.Dataset, or MultiTrace
+        Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()),
+        or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior)
     samples: int
         Number of posterior predictive samples to generate. Defaults to one posterior predictive
         sample per posterior sample, that is, the number of draws times the number of chains. It
@@ -1556,6 +1557,23 @@ def sample_posterior_predictive(
         Dictionary with the variable names as keys, and values numpy arrays containing
         posterior predictive samples.
     """
+    if isinstance(trace, xarray.Dataset):
+        # grab posterior samples for each variable
+        _samples = {
+            vn : trace[vn].values
+            for vn in trace.keys()
+        }
+        # make dicts
+        points = []
+        for c in trace.chain:
+            for d in trace.draw:
+                points.append({
+                    vn : s[c, d]
+                    for vn, s in _samples.items()
+                })
+        # use the list of points
+        trace = points
+
     len_trace = len(trace)
     try:
         nchain = trace.nchains
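
The conversion added in this hunk can be tried outside of PyMC3. The following is a minimal standalone sketch, not the code from the commit: the helper name `dataset_to_points` and the toy `posterior` dataset are hypothetical, and the sketch assumes every variable has `chain` and `draw` as its first two dimensions. It also iterates `data_vars` and uses positional indices rather than `trace.keys()` and the coordinate values; those are choices of this sketch.

```python
import numpy as np
import xarray as xr


def dataset_to_points(ds):
    """Flatten a (chain, draw, ...) Dataset into a list of per-draw dicts.

    Hypothetical helper for illustration; assumes `chain` and `draw` are the
    first two dimensions of every data variable.
    """
    # Pull out the raw NumPy arrays once, keyed by variable name.
    samples = {name: ds[name].values for name in ds.data_vars}
    points = []
    # Positional indices are used here (an assumption of this sketch), so the
    # result does not depend on the coordinate labels of `chain` and `draw`.
    for c in range(ds.sizes["chain"]):
        for d in range(ds.sizes["draw"]):
            points.append({name: arr[c, d] for name, arr in samples.items()})
    return points


# Toy posterior: 2 chains x 3 draws, one scalar and one length-2 vector variable.
posterior = xr.Dataset(
    {
        "mu": (("chain", "draw"), np.arange(6.0).reshape(2, 3)),
        "beta": (("chain", "draw", "beta_dim"), np.arange(12.0).reshape(2, 3, 2)),
    },
    coords={"chain": [0, 1], "draw": [0, 1, 2]},
)

points = dataset_to_points(posterior)
```

Each element of `points` is a dict mapping variable names to the value for one (chain, draw) pair, which is the same list-of-points shape that `sample_posterior_predictive` already accepted before this commit.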
