Skip to content

Fix group selection in sample_posterior_predictive when predictions=True is passed in kwargs #426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
21 changes: 16 additions & 5 deletions pymc_extras/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ def predict(
self,
X_pred: np.ndarray | pd.DataFrame | pd.Series,
extend_idata: bool = True,
predictions: bool = True,
**kwargs,
) -> np.ndarray:
"""
Expand All @@ -542,6 +543,9 @@ def predict(
The input data used for prediction.
extend_idata : Boolean determining whether the predictions should be added to inference data object.
Defaults to True.
predictions : bool
Whether to use the predictions group for posterior predictive sampling.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive

Returns
Expand All @@ -559,7 +563,7 @@ def predict(
"""

posterior_predictive_samples = self.sample_posterior_predictive(
X_pred, extend_idata, combined=False, **kwargs
X_pred, extend_idata, combined=False, predictions=predictions, **kwargs
)

if self.output_var not in posterior_predictive_samples:
Expand Down Expand Up @@ -624,7 +628,7 @@ def sample_prior_predictive(

return prior_predictive_samples

def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
def sample_posterior_predictive(self, X_pred, extend_idata, combined, predictions = True, **kwargs):
"""
Sample from the model's posterior predictive distribution.

Expand All @@ -634,6 +638,8 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
The input data used for prediction using prior distribution..
extend_idata : Boolean determining whether the predictions should be added to inference data object.
Defaults to False.
predictions : Boolean determing whether to use the predictions group for posterior predictive sampling.
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
Expand All @@ -646,12 +652,14 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
self._data_setter(X_pred)

with self.model: # sample with new input data
post_pred = pm.sample_posterior_predictive(self.idata, **kwargs)
post_pred = pm.sample_posterior_predictive(self.idata, predictions=predictions, **kwargs)
if extend_idata:
self.idata.extend(post_pred, join="right")

group_name = "predictions" if predictions else "posterior_predictive"

posterior_predictive_samples = az.extract(
post_pred, "posterior_predictive", combined=combined
post_pred, group_name, combined=combined
)

return posterior_predictive_samples
Expand Down Expand Up @@ -700,6 +708,7 @@ def predict_posterior(
X_pred: np.ndarray | pd.DataFrame | pd.Series,
extend_idata: bool = True,
combined: bool = True,
predictions: bool = True,
**kwargs,
) -> xr.DataArray:
"""
Expand All @@ -713,6 +722,8 @@ def predict_posterior(
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
predictions : Boolean determing whether to use the predictions group for posterior predictive sampling.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive

Returns
Expand All @@ -723,7 +734,7 @@ def predict_posterior(

X_pred = self._validate_data(X_pred)
posterior_predictive_samples = self.sample_posterior_predictive(
X_pred, extend_idata, combined, **kwargs
X_pred, extend_idata, combined, predictions=predictions, **kwargs
)

if self.output_var not in posterior_predictive_samples:
Expand Down