Skip to content

Commit ce1b2d5

Browse files
committed
refactor: pass predictions by keyword
1 parent fb6b1e9 commit ce1b2d5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pymc_extras/model_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ def predict(
563563
"""
564564

565565
posterior_predictive_samples = self.sample_posterior_predictive(
566-
X_pred, extend_idata, predictions, combined=False, **kwargs
566+
X_pred, extend_idata, combined=False, predictions=predictions, **kwargs
567567
)
568568

569569
if self.output_var not in posterior_predictive_samples:
@@ -628,7 +628,7 @@ def sample_prior_predictive(
628628

629629
return prior_predictive_samples
630630

631-
def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combined, **kwargs):
631+
def sample_posterior_predictive(self, X_pred, extend_idata, combined, predictions = True, **kwargs):
632632
"""
633633
Sample from the model's posterior predictive distribution.
634634
@@ -734,7 +734,7 @@ def predict_posterior(
734734

735735
X_pred = self._validate_data(X_pred)
736736
posterior_predictive_samples = self.sample_posterior_predictive(
737-
X_pred, extend_idata, predictions, combined, **kwargs
737+
X_pred, extend_idata, combined, predictions=predictions, **kwargs
738738
)
739739

740740
if self.output_var not in posterior_predictive_samples:

0 commit comments

Comments
 (0)