diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 6e712e5d..10d34f29 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -530,6 +530,7 @@ def predict( self, X_pred: np.ndarray | pd.DataFrame | pd.Series, extend_idata: bool = True, + predictions: bool = True, **kwargs, ) -> np.ndarray: """ @@ -542,6 +543,9 @@ def predict( The input data used for prediction. extend_idata : Boolean determining whether the predictions should be added to inference data object. Defaults to True. + predictions : bool + Whether to use the predictions group for posterior predictive sampling. + Defaults to True. **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive Returns @@ -559,7 +563,7 @@ def predict( """ posterior_predictive_samples = self.sample_posterior_predictive( - X_pred, extend_idata, combined=False, **kwargs + X_pred, extend_idata, combined=False, predictions=predictions, **kwargs ) if self.output_var not in posterior_predictive_samples: @@ -624,7 +628,9 @@ def sample_prior_predictive( return prior_predictive_samples - def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): + def sample_posterior_predictive( + self, X_pred, extend_idata, combined, predictions=True, **kwargs + ): """ Sample from the model's posterior predictive distribution. @@ -634,6 +640,8 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): The input data used for prediction using prior distribution.. extend_idata : Boolean determining whether the predictions should be added to inference data object. Defaults to False. + predictions : Boolean determining whether to use the predictions group for posterior predictive sampling. + Defaults to True. combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists. Defaults to True. 
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive @@ -646,13 +654,15 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): self._data_setter(X_pred) with self.model: # sample with new input data - post_pred = pm.sample_posterior_predictive(self.idata, **kwargs) + post_pred = pm.sample_posterior_predictive( + self.idata, predictions=predictions, **kwargs + ) if extend_idata: self.idata.extend(post_pred, join="right") - posterior_predictive_samples = az.extract( - post_pred, "posterior_predictive", combined=combined - ) + group_name = "predictions" if predictions else "posterior_predictive" + + posterior_predictive_samples = az.extract(post_pred, group_name, combined=combined) return posterior_predictive_samples @@ -700,6 +710,7 @@ def predict_posterior( X_pred: np.ndarray | pd.DataFrame | pd.Series, extend_idata: bool = True, combined: bool = True, + predictions: bool = True, **kwargs, ) -> xr.DataArray: """ @@ -713,6 +724,8 @@ def predict_posterior( Defaults to True. combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists. Defaults to True. + predictions : Boolean determining whether to use the predictions group for posterior predictive sampling. + Defaults to True. 
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive Returns @@ -723,7 +736,7 @@ def predict_posterior( X_pred = self._validate_data(X_pred) posterior_predictive_samples = self.sample_posterior_predictive( - X_pred, extend_idata, combined, **kwargs + X_pred, extend_idata, combined, predictions=predictions, **kwargs ) if self.output_var not in posterior_predictive_samples: diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 9494bb10..88dd971d 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -304,3 +304,36 @@ def test_id(): ).hexdigest()[:16] assert model_builder.id == expected_id + + +@pytest.mark.parametrize("predictions", [True, False]) +def test_predict_respects_predictions_flag(fitted_model_instance, predictions): + """predict() should write to the 'predictions' group only when predictions=True.""" + x_pred = np.random.uniform(0, 1, 100) + prediction_data = pd.DataFrame({"input": x_pred}) + output_var = fitted_model_instance.output_var + + # Snapshot the original posterior_predictive values + pp_before = fitted_model_instance.idata.posterior_predictive[output_var].values.copy() + + # Ensure 'predictions' group is not present initially + assert "predictions" not in fitted_model_instance.idata.groups() + + # Run prediction; `combined` is not passed because predict() already forwards combined=False itself + fitted_model_instance.predict( + prediction_data["input"], + extend_idata=True, + predictions=predictions, + ) + + pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values + + # Check predictions group presence + if predictions: + assert "predictions" in fitted_model_instance.idata.groups() + # Posterior predictive should remain unchanged + np.testing.assert_array_equal(pp_before, pp_after) + else: + assert "predictions" not in fitted_model_instance.idata.groups() + # Posterior predictive should be updated (differing shapes also count as not equal) + assert not np.array_equal(pp_before, pp_after)