Skip to content

Commit 662024b

Browse files
mbjoseph, Max Joseph, and twiecki
authored and committed
Return posterior predictive samples from all chains in ModelBuilder (pymc-devs#140)
* Return posterior predictive samples from all chains

  This fixes a bug where only values from one chain were returned. It also refactors the prediction logic to reduce duplication, and makes the output of predict_posterior() consistent in type and shape with the output of pymc.sample_posterior_predictive().

* keep attributes even when computing posterior means
* Add/test combined arg, revert method order
* Fix import order.

---------

Co-authored-by: Max Joseph <[email protected]>
Co-authored-by: Thomas Wiecki <[email protected]>
1 parent d5d6aa1 commit 662024b

File tree

2 files changed

+28
-74
lines changed

2 files changed

+28
-74
lines changed

pymc_experimental/model_builder.py

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy as np
2424
import pandas as pd
2525
import pymc as pm
26+
import xarray as xr
2627
from pymc.util import RandomState
2728

2829

@@ -306,7 +307,7 @@ def predict(
306307
self,
307308
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
308309
extend_idata: bool = True,
309-
) -> dict:
310+
) -> xr.Dataset:
310311
"""
311312
Uses model to predict on unseen data and return point prediction of all the samples
312313
@@ -319,7 +320,7 @@ def predict(
319320
320321
Returns
321322
-------
322-
returns dictionary of sample's mean of posterior predict.
323+
returns posterior mean of predictive samples
323324
324325
Examples
325326
--------
@@ -328,42 +329,33 @@ def predict(
328329
>>> idata = model.fit(data)
329330
>>> x_pred = []
330331
>>> prediction_data = pd.DataFrame({'input':x_pred})
331-
# point predict
332332
>>> pred_mean = model.predict(prediction_data)
333333
"""
334-
335-
if data_prediction is not None: # set new input data
336-
self._data_setter(data_prediction)
337-
338-
with self.model: # sample with new input data
339-
post_pred = pm.sample_posterior_predictive(self.idata)
340-
if extend_idata:
341-
self.idata.extend(post_pred)
342-
# reshape output
343-
post_pred = self._extract_samples(post_pred)
344-
for key in post_pred:
345-
post_pred[key] = post_pred[key].mean(axis=0)
346-
347-
return post_pred
334+
posterior_predictive_samples = self.predict_posterior(data_prediction, extend_idata)
335+
posterior_means = posterior_predictive_samples.mean(dim=["chain", "draw"], keep_attrs=True)
336+
return posterior_means
348337

349338
def predict_posterior(
350339
self,
351340
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
352341
extend_idata: bool = True,
353-
) -> Dict[str, np.array]:
342+
combined: bool = False,
343+
) -> xr.Dataset:
354344
"""
355-
Uses model to predict samples on unseen data.
345+
Generate posterior predictive samples on unseen data.
356346
357347
Parameters
358348
---------
359349
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
360350
It is the data we need to make prediction on using the model.
361351
extend_idata : Boolean determining whether the predictions should be added to inference data object.
362352
Defaults to True.
353+
combined: Combine chain and draw dims into sample. Won’t work if a dim named sample already exists.
354+
Defaults to False.
363355
364356
Returns
365357
-------
366-
returns dictionary of sample's posterior predict.
358+
returns posterior predictive samples
367359
368360
Examples
369361
--------
@@ -372,8 +364,7 @@ def predict_posterior(
372364
>>> idata = model.fit(data)
373365
>>> x_pred = []
374366
>>> prediction_data = pd.DataFrame({'input': x_pred})
375-
# samples
376-
>>> pred_mean = model.predict_posterior(prediction_data)
367+
>>> pred_samples = model.predict_posterior(prediction_data)
377368
"""
378369

379370
if data_prediction is not None: # set new input data
@@ -384,30 +375,11 @@ def predict_posterior(
384375
if extend_idata:
385376
self.idata.extend(post_pred)
386377

387-
# reshape output
388-
post_pred = self._extract_samples(post_pred)
389-
390-
return post_pred
391-
392-
@staticmethod
393-
def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[str, np.array]:
394-
"""
395-
This method can be used to extract samples from posterior predict.
396-
397-
Parameters
398-
----------
399-
post_pred: arviz InferenceData object
400-
401-
Returns
402-
-------
403-
Dictionary of numpy arrays from InferenceData object
404-
"""
405-
406-
post_pred_dict = dict()
407-
for key in post_pred.posterior_predictive:
408-
post_pred_dict[key] = post_pred.posterior_predictive[key].to_numpy()[0]
378+
posterior_predictive_samples = az.extract(
379+
post_pred, "posterior_predictive", combined=combined
380+
)
409381

410-
return post_pred_dict
382+
return posterior_predictive_samples
411383

412384
@property
413385
@abstractmethod

pymc_experimental/tests/test_model_builder.py

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -158,40 +158,22 @@ def test_predict():
158158
prediction_data = pd.DataFrame({"input": x_pred})
159159
pred = model.predict(prediction_data)
160160
assert "y_model" in pred
161-
assert isinstance(pred, dict)
162161
assert len(prediction_data.input.values) == len(pred["y_model"])
163-
assert isinstance(pred["y_model"][0], (np.float32, np.float64))
162+
assert np.issubdtype(pred["y_model"].dtype, np.floating)
164163

165164

166-
def test_predict_posterior():
165+
@pytest.mark.parametrize("combined", [True, False])
166+
def test_predict_posterior(combined):
167167
model = test_ModelBuilder.initial_build_and_fit()
168-
x_pred = np.random.uniform(low=0, high=1, size=100)
168+
n_pred = 100
169+
x_pred = np.random.uniform(low=0, high=1, size=n_pred)
169170
prediction_data = pd.DataFrame({"input": x_pred})
170-
pred = model.predict_posterior(prediction_data)
171-
assert "y_model" in pred
172-
assert isinstance(pred, dict)
173-
assert len(prediction_data.input.values) == len(pred["y_model"][0])
174-
assert isinstance(pred["y_model"][0], np.ndarray)
175-
176-
177-
def test_extract_samples():
178-
# create a fake InferenceData object
179-
with pm.Model() as model:
180-
x = pm.Normal("x", mu=0, sigma=1)
181-
intercept = pm.Normal("intercept", mu=0, sigma=1)
182-
y_model = pm.Normal("y_model", mu=x * intercept, sigma=1, observed=[0, 1, 2])
183-
184-
idata = pm.sample(1000, tune=1000)
185-
post_pred = pm.sample_posterior_predictive(idata)
186-
187-
# call the function and get the output
188-
samples_dict = test_ModelBuilder._extract_samples(post_pred)
189-
190-
# assert that the keys and values are correct
191-
assert len(samples_dict) == len(post_pred.posterior_predictive)
192-
for key in post_pred.posterior_predictive:
193-
expected_value = post_pred.posterior_predictive[key].to_numpy()[0]
194-
assert np.array_equal(samples_dict[key], expected_value)
171+
pred = model.predict_posterior(prediction_data, combined=combined)
172+
chains = model.idata.sample_stats.dims["chain"]
173+
draws = model.idata.sample_stats.dims["draw"]
174+
expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
175+
assert pred["y_model"].shape == expected_shape
176+
assert np.issubdtype(pred["y_model"].dtype, np.floating)
195177

196178

197179
def test_id():

0 commit comments

Comments
 (0)