Skip to content

Commit ba018b7

Browse files
GoosericardoV94
Goose
authored andcommitted
add return type overload for sample_posterior_predictive
1 parent e9b3c99 commit ba018b7

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

pymc/sampling/forward.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,36 @@ def sample_prior_predictive(
493493
return pm.to_inference_data(prior=prior, **ikwargs)
494494

495495

496+
@overload
497+
def sample_posterior_predictive(
498+
trace,
499+
model: Model | None = None,
500+
var_names: list[str] | None = None,
501+
sample_dims: list[str] | None = None,
502+
random_seed: RandomState = None,
503+
progressbar: bool = True,
504+
progressbar_theme: Theme | None = default_progress_theme,
505+
return_inferencedata: Literal[True] = True,
506+
extend_inferencedata: bool = False,
507+
predictions: bool = False,
508+
idata_kwargs: dict | None = None,
509+
compile_kwargs: dict | None = None,
510+
) -> InferenceData: ...
511+
@overload
512+
def sample_posterior_predictive(
513+
trace,
514+
model: Model | None = None,
515+
var_names: list[str] | None = None,
516+
sample_dims: list[str] | None = None,
517+
random_seed: RandomState = None,
518+
progressbar: bool = True,
519+
progressbar_theme: Theme | None = default_progress_theme,
520+
return_inferencedata: Literal[False] = False,
521+
extend_inferencedata: bool = False,
522+
predictions: bool = False,
523+
idata_kwargs: dict | None = None,
524+
compile_kwargs: dict | None = None,
525+
) -> dict[str, np.ndarray]: ...
496526
def sample_posterior_predictive(
497527
trace,
498528
model: Model | None = None,

0 commit comments

Comments
 (0)