diff --git a/pymc/sampling.py b/pymc/sampling.py index df3dcdd213..15609eac59 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -2032,8 +2032,6 @@ def sample_posterior_predictive_w( weighted models (default), or a dictionary with variable names as keys, and samples as numpy arrays. """ - raise NotImplementedError(f"sample_posterior_predictive_w has not yet been ported to PyMC 4.0.") - if isinstance(traces[0], InferenceData): n_samples = [ trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces @@ -2140,13 +2138,13 @@ def sample_posterior_predictive_w( # TODO sample_posterior_predictive_w is currently only work for model with # one observed. # XXX: This needs to be refactored - # ppc[var.name].append(draw_values([var], point=param, size=size[idx])[0]) - raise NotImplementedError() + ppcl[var.name].append(draw([var])[0]) except KeyboardInterrupt: pass else: ppcd = {k: np.asarray(v) for k, v in ppcl.items()} + return ppcd if not return_inferencedata: return ppcd ikwargs: Dict[str, Any] = dict(model=models) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index f286581baa..799121a50e 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -935,9 +935,6 @@ def test_deterministics_out_of_idata(self, multitrace): assert np.all(np.abs(ppc.posterior_predictive.c + 4) <= 0.1) -@pytest.mark.xfail( - reason="sample_posterior_predictive_w not refactored for v4", raises=NotImplementedError -) class TestSamplePPCW(SeededTest): def test_sample_posterior_predictive_w(self): data0 = np.random.normal(0, 1, size=50)