|
30 | 30 | import numpy as np
|
31 | 31 | import theano.gradient as tg
|
32 | 32 | from theano.tensor import Tensor
|
| 33 | +import xarray |
33 | 34 |
|
34 | 35 | from .backends.base import BaseTrace, MultiTrace
|
35 | 36 | from .backends.ndarray import NDArray
|
@@ -1520,9 +1521,9 @@ def sample_posterior_predictive(
|
1520 | 1521 |
|
1521 | 1522 | Parameters
|
1522 | 1523 | ----------
|
1523 |
| - trace: backend, list, or MultiTrace |
1524 |
| - Trace generated from MCMC sampling. Or a list containing dicts from |
1525 |
| - find_MAP() or points |
| 1524 | + trace: backend, list, xarray.Dataset, or MultiTrace |
| 1525 | + Trace generated from MCMC sampling, or a list of dicts (e.g. points or from find_MAP()), |
| 1526 | + or xarray.Dataset (e.g. InferenceData.posterior or InferenceData.prior) |
1526 | 1527 | samples: int
|
1527 | 1528 | Number of posterior predictive samples to generate. Defaults to one posterior predictive
|
1528 | 1529 | sample per posterior sample, that is, the number of draws times the number of chains. It
|
@@ -1556,6 +1557,23 @@ def sample_posterior_predictive(
|
1556 | 1557 | Dictionary with the variable names as keys, and values numpy arrays containing
|
1557 | 1558 | posterior predictive samples.
|
1558 | 1559 | """
|
| 1560 | + if isinstance(trace, xarray.Dataset): |
| 1561 | + # grab posterior samples for each variable |
| 1562 | + _samples = { |
| 1563 | + vn : trace[vn].values |
| 1564 | + for vn in trace.keys() |
| 1565 | + } |
| 1566 | + # make dicts |
| 1567 | + points = [] |
| 1568 | + for c in trace.chain: |
| 1569 | + for d in trace.draw: |
| 1570 | + points.append({ |
| 1571 | + vn : s[c, d] |
| 1572 | + for vn, s in _samples.items() |
| 1573 | + }) |
| 1574 | + # use the list of points |
| 1575 | + trace = points |
| 1576 | + |
1559 | 1577 | len_trace = len(trace)
|
1560 | 1578 | try:
|
1561 | 1579 | nchain = trace.nchains
|
|
0 commit comments