Skip to content

Commit b3be15f

Browse files
mbjoseph (Max Joseph) and twiecki
authored
Return posterior predictive samples from all chains in ModelBuilder (#140)
* Return posterior predictive samples from all chains

  This fixes a bug where only values from one chain were returned. It also refactors the prediction logic to reduce duplication, and makes the output of predict_posterior() consistent in type and shape with the output of pymc.sample_posterior_predictive().

* Keep attributes even when computing posterior means
* Add/test combined arg, revert method order
* Fix import order.

---------

Co-authored-by: Max Joseph <[email protected]>
Co-authored-by: Thomas Wiecki <[email protected]>
1 parent e38da06 commit b3be15f

File tree

2 files changed

+28
-74
lines changed

2 files changed

+28
-74
lines changed

pymc_experimental/model_builder.py

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy as np
2424
import pandas as pd
2525
import pymc as pm
26+
import xarray as xr
2627
from pymc.util import RandomState
2728

2829

@@ -286,7 +287,7 @@ def predict(
286287
self,
287288
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
288289
extend_idata: bool = True,
289-
) -> dict:
290+
) -> xr.Dataset:
290291
"""
291292
Uses model to predict on unseen data and return point prediction of all the samples
292293
@@ -299,7 +300,7 @@ def predict(
299300
300301
Returns
301302
-------
302-
returns dictionary of sample's mean of posterior predict.
303+
returns posterior mean of predictive samples
303304
304305
Examples
305306
--------
@@ -308,42 +309,33 @@ def predict(
308309
>>> idata = model.fit(data)
309310
>>> x_pred = []
310311
>>> prediction_data = pd.DataFrame({'input':x_pred})
311-
# point predict
312312
>>> pred_mean = model.predict(prediction_data)
313313
"""
314-
315-
if data_prediction is not None: # set new input data
316-
self._data_setter(data_prediction)
317-
318-
with self.model: # sample with new input data
319-
post_pred = pm.sample_posterior_predictive(self.idata)
320-
if extend_idata:
321-
self.idata.extend(post_pred)
322-
# reshape output
323-
post_pred = self._extract_samples(post_pred)
324-
for key in post_pred:
325-
post_pred[key] = post_pred[key].mean(axis=0)
326-
327-
return post_pred
314+
posterior_predictive_samples = self.predict_posterior(data_prediction, extend_idata)
315+
posterior_means = posterior_predictive_samples.mean(dim=["chain", "draw"], keep_attrs=True)
316+
return posterior_means
328317

329318
def predict_posterior(
330319
self,
331320
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
332321
extend_idata: bool = True,
333-
) -> Dict[str, np.array]:
322+
combined: bool = False,
323+
) -> xr.Dataset:
334324
"""
335-
Uses model to predict samples on unseen data.
325+
Generate posterior predictive samples on unseen data.
336326
337327
Parameters
338328
---------
339329
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
340330
It is the data we need to make prediction on using the model.
341331
extend_idata : Boolean determining whether the predictions should be added to inference data object.
342332
Defaults to True.
333+
combined: Combine chain and draw dims into sample. Won’t work if a dim named sample already exists.
334+
Defaults to False.
343335
344336
Returns
345337
-------
346-
returns dictionary of sample's posterior predict.
338+
returns posterior predictive samples
347339
348340
Examples
349341
--------
@@ -352,8 +344,7 @@ def predict_posterior(
352344
>>> idata = model.fit(data)
353345
>>> x_pred = []
354346
>>> prediction_data = pd.DataFrame({'input': x_pred})
355-
# samples
356-
>>> pred_mean = model.predict_posterior(prediction_data)
347+
>>> pred_samples = model.predict_posterior(prediction_data)
357348
"""
358349

359350
if data_prediction is not None: # set new input data
@@ -364,30 +355,11 @@ def predict_posterior(
364355
if extend_idata:
365356
self.idata.extend(post_pred)
366357

367-
# reshape output
368-
post_pred = self._extract_samples(post_pred)
369-
370-
return post_pred
371-
372-
@staticmethod
373-
def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[str, np.array]:
374-
"""
375-
This method can be used to extract samples from posterior predict.
376-
377-
Parameters
378-
----------
379-
post_pred: arviz InferenceData object
380-
381-
Returns
382-
-------
383-
Dictionary of numpy arrays from InferenceData object
384-
"""
385-
386-
post_pred_dict = dict()
387-
for key in post_pred.posterior_predictive:
388-
post_pred_dict[key] = post_pred.posterior_predictive[key].to_numpy()[0]
358+
posterior_predictive_samples = az.extract(
359+
post_pred, "posterior_predictive", combined=combined
360+
)
389361

390-
return post_pred_dict
362+
return posterior_predictive_samples
391363

392364
@property
393365
@abstractmethod

pymc_experimental/tests/test_model_builder.py

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -158,40 +158,22 @@ def test_predict():
158158
prediction_data = pd.DataFrame({"input": x_pred})
159159
pred = model.predict(prediction_data)
160160
assert "y_model" in pred
161-
assert isinstance(pred, dict)
162161
assert len(prediction_data.input.values) == len(pred["y_model"])
163-
assert isinstance(pred["y_model"][0], (np.float32, np.float64))
162+
assert np.issubdtype(pred["y_model"].dtype, np.floating)
164163

165164

166-
def test_predict_posterior():
165+
@pytest.mark.parametrize("combined", [True, False])
166+
def test_predict_posterior(combined):
167167
model = test_ModelBuilder.initial_build_and_fit()
168-
x_pred = np.random.uniform(low=0, high=1, size=100)
168+
n_pred = 100
169+
x_pred = np.random.uniform(low=0, high=1, size=n_pred)
169170
prediction_data = pd.DataFrame({"input": x_pred})
170-
pred = model.predict_posterior(prediction_data)
171-
assert "y_model" in pred
172-
assert isinstance(pred, dict)
173-
assert len(prediction_data.input.values) == len(pred["y_model"][0])
174-
assert isinstance(pred["y_model"][0], np.ndarray)
175-
176-
177-
def test_extract_samples():
178-
# create a fake InferenceData object
179-
with pm.Model() as model:
180-
x = pm.Normal("x", mu=0, sigma=1)
181-
intercept = pm.Normal("intercept", mu=0, sigma=1)
182-
y_model = pm.Normal("y_model", mu=x * intercept, sigma=1, observed=[0, 1, 2])
183-
184-
idata = pm.sample(1000, tune=1000)
185-
post_pred = pm.sample_posterior_predictive(idata)
186-
187-
# call the function and get the output
188-
samples_dict = test_ModelBuilder._extract_samples(post_pred)
189-
190-
# assert that the keys and values are correct
191-
assert len(samples_dict) == len(post_pred.posterior_predictive)
192-
for key in post_pred.posterior_predictive:
193-
expected_value = post_pred.posterior_predictive[key].to_numpy()[0]
194-
assert np.array_equal(samples_dict[key], expected_value)
171+
pred = model.predict_posterior(prediction_data, combined=combined)
172+
chains = model.idata.sample_stats.dims["chain"]
173+
draws = model.idata.sample_stats.dims["draw"]
174+
expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
175+
assert pred["y_model"].shape == expected_shape
176+
assert np.issubdtype(pred["y_model"].dtype, np.floating)
195177

196178

197179
def test_id():

0 commit comments

Comments
 (0)