From 34839ad48996f504bdd6504d09712269b3a59e6a Mon Sep 17 00:00:00 2001 From: butterman0 Date: Mon, 17 Feb 2025 14:12:27 +0100 Subject: [PATCH 1/8] Fix group selection for posterior predictive samples when predictions = True --- pymc_extras/model_builder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 6e712e5d..98fbe7d5 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -650,8 +650,11 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): if extend_idata: self.idata.extend(post_pred, join="right") + # Determine the correct group dynamically + group_name = "predictions" if kwargs.get("predictions", False) else "posterior_predictive" + posterior_predictive_samples = az.extract( - post_pred, "posterior_predictive", combined=combined + post_pred, group_name, combined=combined ) return posterior_predictive_samples From 7cdd90050c57e559f0e9734ce485a1ed7c88382b Mon Sep 17 00:00:00 2001 From: butterman0 Date: Tue, 18 Feb 2025 12:23:49 +0100 Subject: [PATCH 2/8] refactor: make predictions argument explicit --- pymc_extras/model_builder.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 98fbe7d5..82fd07ec 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -530,6 +530,7 @@ def predict( self, X_pred: np.ndarray | pd.DataFrame | pd.Series, extend_idata: bool = True, + predictions: bool = False, **kwargs, ) -> np.ndarray: """ @@ -559,7 +560,7 @@ def predict( """ posterior_predictive_samples = self.sample_posterior_predictive( - X_pred, extend_idata, combined=False, **kwargs + X_pred, extend_idata, predictions, combined=False, **kwargs ) if self.output_var not in posterior_predictive_samples: @@ -624,7 +625,7 @@ def sample_prior_predictive( return prior_predictive_samples - def 
sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): + def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combined, **kwargs): """ Sample from the model's posterior predictive distribution. @@ -646,12 +647,12 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): self._data_setter(X_pred) with self.model: # sample with new input data - post_pred = pm.sample_posterior_predictive(self.idata, **kwargs) + post_pred = pm.sample_posterior_predictive(self.idata, predictions=predictions, **kwargs) if extend_idata: self.idata.extend(post_pred, join="right") - # Determine the correct group dynamically - group_name = "predictions" if kwargs.get("predictions", False) else "posterior_predictive" + # Determine the correct group + group_name = "predictions" if predictions else "posterior_predictive" posterior_predictive_samples = az.extract( post_pred, group_name, combined=combined @@ -703,6 +704,7 @@ def predict_posterior( X_pred: np.ndarray | pd.DataFrame | pd.Series, extend_idata: bool = True, combined: bool = True, + predictions: bool = False, **kwargs, ) -> xr.DataArray: """ @@ -726,7 +728,7 @@ def predict_posterior( X_pred = self._validate_data(X_pred) posterior_predictive_samples = self.sample_posterior_predictive( - X_pred, extend_idata, combined, **kwargs + X_pred, extend_idata, predictions, combined, **kwargs ) if self.output_var not in posterior_predictive_samples: From a0501fe119ed988082438126c4e7af494d3c933d Mon Sep 17 00:00:00 2001 From: butterman0 Date: Tue, 18 Feb 2025 12:35:05 +0100 Subject: [PATCH 3/8] refactor: change default predictions to True --- pymc_extras/model_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 82fd07ec..2a4ab2ab 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -530,7 +530,7 @@ def predict( self, X_pred: np.ndarray | pd.DataFrame | 
pd.Series, extend_idata: bool = True, - predictions: bool = False, + predictions: bool = True, **kwargs, ) -> np.ndarray: """ @@ -704,7 +704,7 @@ def predict_posterior( X_pred: np.ndarray | pd.DataFrame | pd.Series, extend_idata: bool = True, combined: bool = True, - predictions: bool = False, + predictions: bool = True, **kwargs, ) -> xr.DataArray: """ From 079e1314a9671c12ce482a2cd9e2fa256bcf425c Mon Sep 17 00:00:00 2001 From: butterman0 Date: Tue, 18 Feb 2025 21:36:42 +0100 Subject: [PATCH 4/8] doc: update docstrings --- pymc_extras/model_builder.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 2a4ab2ab..f8fc93de 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -543,6 +543,9 @@ def predict( The input data used for prediction. extend_idata : Boolean determining whether the predictions should be added to inference data object. Defaults to True. + predictions : bool + Whether to use the predictions group for posterior predictive sampling. + Defaults to True. **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive Returns @@ -651,7 +654,6 @@ def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combine if extend_idata: self.idata.extend(post_pred, join="right") - # Determine the correct group group_name = "predictions" if predictions else "posterior_predictive" posterior_predictive_samples = az.extract( @@ -718,6 +720,8 @@ def predict_posterior( Defaults to True. combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists. Defaults to True. + predictions : Boolean determining whether to use the predictions group for posterior predictive sampling. + Defaults to True. 
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive Returns From fb6b1e999e2eaf72102daed122bccef8ef5a7a5a Mon Sep 17 00:00:00 2001 From: butterman0 Date: Thu, 6 Mar 2025 17:01:41 +0100 Subject: [PATCH 5/8] docs: update --- pymc_extras/model_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index f8fc93de..dea962a3 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -638,6 +638,8 @@ def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combine The input data used for prediction using prior distribution.. extend_idata : Boolean determining whether the predictions should be added to inference data object. Defaults to False. + predictions : Boolean determining whether to use the predictions group for posterior predictive sampling. + Defaults to True. combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists. Defaults to True. 
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive From ce1b2d5aed74eafe04483886ae8ef2c7909e2ddd Mon Sep 17 00:00:00 2001 From: butterman0 Date: Thu, 6 Mar 2025 17:05:12 +0100 Subject: [PATCH 6/8] refactor: pass predictions by keyword --- pymc_extras/model_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index dea962a3..ccc29c57 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -563,7 +563,7 @@ def predict( """ posterior_predictive_samples = self.sample_posterior_predictive( - X_pred, extend_idata, predictions, combined=False, **kwargs + X_pred, extend_idata, combined=False, predictions=predictions, **kwargs ) if self.output_var not in posterior_predictive_samples: @@ -628,7 +628,7 @@ def sample_prior_predictive( return prior_predictive_samples - def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combined, **kwargs): + def sample_posterior_predictive(self, X_pred, extend_idata, combined, predictions = True, **kwargs): """ Sample from the model's posterior predictive distribution. 
@@ -734,7 +734,7 @@ def predict_posterior( X_pred = self._validate_data(X_pred) posterior_predictive_samples = self.sample_posterior_predictive( - X_pred, extend_idata, predictions, combined, **kwargs + X_pred, extend_idata, combined, predictions=predictions, **kwargs ) if self.output_var not in posterior_predictive_samples: From 4ea5fbc42b412159cc3e7b85ee6045eef0635b5f Mon Sep 17 00:00:00 2001 From: butterman0 Date: Wed, 9 Apr 2025 15:50:24 +0200 Subject: [PATCH 7/8] test: added test for predictions grouping --- tests/test_model_builder.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 9494bb10..1a4ffa97 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -304,3 +304,35 @@ def test_id(): ).hexdigest()[:16] assert model_builder.id == expected_id + +@pytest.mark.parametrize("predictions", [True, False]) +def test_predict_respects_predictions_flag(fitted_model_instance, predictions): + x_pred = np.random.uniform(0, 1, 100) + prediction_data = pd.DataFrame({"input": x_pred}) + output_var = fitted_model_instance.output_var + + # Snapshot the original posterior_predictive values + pp_before = fitted_model_instance.idata.posterior_predictive[output_var].values.copy() + + # Ensure 'predictions' group is not present initially + assert "predictions" not in fitted_model_instance.idata.groups() + + # Run prediction with predictions=True or False + fitted_model_instance.predict( + prediction_data["input"], + extend_idata=True, + combined=False, + predictions=predictions, + ) + + pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values + + # Check predictions group presence + if predictions: + assert "predictions" in fitted_model_instance.idata.groups() + # Posterior predictive should remain unchanged + np.testing.assert_array_equal(pp_before, pp_after) + else: + assert "predictions" not in fitted_model_instance.idata.groups() + # 
Posterior predictive should be updated + np.testing.assert_array_not_equal(pp_before, pp_after) \ No newline at end of file From 7f84f034448d7e9e8307772d3573333964541831 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 25 Apr 2025 07:41:41 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc_extras/model_builder.py | 12 +++++++----- tests/test_model_builder.py | 5 +++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index ccc29c57..10d34f29 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -628,7 +628,9 @@ def sample_prior_predictive( return prior_predictive_samples - def sample_posterior_predictive(self, X_pred, extend_idata, combined, predictions = True, **kwargs): + def sample_posterior_predictive( + self, X_pred, extend_idata, combined, predictions=True, **kwargs + ): """ Sample from the model's posterior predictive distribution. 
@@ -652,15 +654,15 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, prediction self._data_setter(X_pred) with self.model: # sample with new input data - post_pred = pm.sample_posterior_predictive(self.idata, predictions=predictions, **kwargs) + post_pred = pm.sample_posterior_predictive( + self.idata, predictions=predictions, **kwargs + ) if extend_idata: self.idata.extend(post_pred, join="right") group_name = "predictions" if predictions else "posterior_predictive" - posterior_predictive_samples = az.extract( - post_pred, group_name, combined=combined - ) + posterior_predictive_samples = az.extract(post_pred, group_name, combined=combined) return posterior_predictive_samples diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 1a4ffa97..88dd971d 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -305,6 +305,7 @@ def test_id(): assert model_builder.id == expected_id + @pytest.mark.parametrize("predictions", [True, False]) def test_predict_respects_predictions_flag(fitted_model_instance, predictions): x_pred = np.random.uniform(0, 1, 100) @@ -324,7 +325,7 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions): combined=False, predictions=predictions, ) - + pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values # Check predictions group presence @@ -335,4 +336,4 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions): else: assert "predictions" not in fitted_model_instance.idata.groups() # Posterior predictive should be updated - np.testing.assert_array_not_equal(pp_before, pp_after) \ No newline at end of file + np.testing.assert_array_not_equal(pp_before, pp_after)