Skip to content

Return posterior predictive samples from all chains in ModelBuilder #140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 19 additions & 52 deletions pymc_experimental/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr


class ModelBuilder:
Expand Down Expand Up @@ -244,13 +245,13 @@ def fit(
self.idata.add_groups(fit_data=self.data.to_xarray())
return self.idata

def predict(
def predict_posterior(
self,
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
extend_idata: bool = True,
) -> dict:
) -> xr.Dataset:
"""
Uses model to predict on unseen data and return point prediction of all the samples
Generate posterior predictive samples on unseen data.

Parameters
----------
Expand All @@ -261,17 +262,16 @@ def predict(

Returns
-------
returns dictionary of sample's mean of posterior predict.
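xr.Dataset of posterior predictive samples, with the chain and draw dimensions kept separate (not combined into a single sample dimension).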

Examples
--------
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
>>> model = LinearModel(model_config, sampler_config)
>>> idata = model.fit(data)
>>> x_pred = []
>>> prediction_data = pd.DataFrame({'input':x_pred})
# point predict
>>> pred_mean = model.predict(prediction_data)
>>> prediction_data = pd.DataFrame({'input': x_pred})
>>> pred_samples = model.predict_posterior(prediction_data)
"""

if data_prediction is not None: # set new input data
Expand All @@ -281,20 +281,18 @@ def predict(
post_pred = pm.sample_posterior_predictive(self.idata)
if extend_idata:
self.idata.extend(post_pred)
# reshape output
post_pred = self._extract_samples(post_pred)
for key in post_pred:
post_pred[key] = post_pred[key].mean(axis=0)

return post_pred
posterior_predictive_samples = az.extract(post_pred, "posterior_predictive", combined=False)

def predict_posterior(
return posterior_predictive_samples

def predict(
self,
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
extend_idata: bool = True,
) -> Dict[str, np.array]:
) -> xr.Dataset:
"""
Uses model to predict samples on unseen data.
Use the model to predict on unseen data and return point predictions, i.e. the posterior predictive mean taken over all chains and draws.

Parameters
----------
Expand All @@ -305,51 +303,20 @@ def predict_posterior(

Returns
-------
returns dictionary of sample's posterior predict.
xr.Dataset holding the mean of the posterior predictive samples, averaged over the chain and draw dimensions.

Examples
--------
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
>>> model = LinearModel(model_config, sampler_config)
>>> idata = model.fit(data)
>>> x_pred = []
>>> prediction_data = pd.DataFrame({'input': x_pred})
# samples
>>> pred_mean = model.predict_posterior(prediction_data)
"""

if data_prediction is not None: # set new input data
self._data_setter(data_prediction)

with self.model: # sample with new input data
post_pred = pm.sample_posterior_predictive(self.idata)
if extend_idata:
self.idata.extend(post_pred)

# reshape output
post_pred = self._extract_samples(post_pred)

return post_pred

@staticmethod
def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[str, np.array]:
"""
This method can be used to extract samples from posterior predict.

Parameters
----------
post_pred: arviz InferenceData object

Returns
-------
Dictionary of numpy arrays from InferenceData object
>>> prediction_data = pd.DataFrame({'input':x_pred})
>>> pred_mean = model.predict(prediction_data)
"""

post_pred_dict = dict()
for key in post_pred.posterior_predictive:
post_pred_dict[key] = post_pred.posterior_predictive[key].to_numpy()[0]

return post_pred_dict
posterior_predictive_samples = self.predict_posterior(data_prediction, extend_idata)
posterior_means = posterior_predictive_samples.mean(dim=["chain", "draw"], keep_attrs=True)
return posterior_means

@property
def id(self) -> str:
Expand Down
34 changes: 7 additions & 27 deletions pymc_experimental/tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class test_ModelBuilder(ModelBuilder):
version = "0.1"

def build(self):

with pm.Model() as self.model:
if self.data is not None:
x = pm.MutableData("x", self.data["input"].values)
Expand Down Expand Up @@ -148,40 +147,21 @@ def test_predict():
prediction_data = pd.DataFrame({"input": x_pred})
pred = model.predict(prediction_data)
assert "y_model" in pred
assert isinstance(pred, dict)
assert len(prediction_data.input.values) == len(pred["y_model"])
assert isinstance(pred["y_model"][0], (np.float32, np.float64))
assert np.issubdtype(pred["y_model"].dtype, np.floating)


def test_predict_posterior():
model = test_ModelBuilder.initial_build_and_fit()
x_pred = np.random.uniform(low=0, high=1, size=100)
n_pred = 100
x_pred = np.random.uniform(low=0, high=1, size=n_pred)
prediction_data = pd.DataFrame({"input": x_pred})
pred = model.predict_posterior(prediction_data)
chains = model.idata.sample_stats.dims["chain"]
draws = model.idata.sample_stats.dims["draw"]
assert "y_model" in pred
assert isinstance(pred, dict)
assert len(prediction_data.input.values) == len(pred["y_model"][0])
assert isinstance(pred["y_model"][0], np.ndarray)


def test_extract_samples():
# create a fake InferenceData object
with pm.Model() as model:
x = pm.Normal("x", mu=0, sigma=1)
intercept = pm.Normal("intercept", mu=0, sigma=1)
y_model = pm.Normal("y_model", mu=x * intercept, sigma=1, observed=[0, 1, 2])

idata = pm.sample(1000, tune=1000)
post_pred = pm.sample_posterior_predictive(idata)

# call the function and get the output
samples_dict = test_ModelBuilder._extract_samples(post_pred)

# assert that the keys and values are correct
assert len(samples_dict) == len(post_pred.posterior_predictive)
for key in post_pred.posterior_predictive:
expected_value = post_pred.posterior_predictive[key].to_numpy()[0]
assert np.array_equal(samples_dict[key], expected_value)
assert pred["y_model"].shape == (chains, draws, n_pred)
assert np.issubdtype(pred["y_model"].dtype, np.floating)


def test_id():
Expand Down