diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index ee2461062..f6bcd3663 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -163,3 +163,9 @@ def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg): test_prior = pm.sample_prior_predictive(1) names = [p["name"] for p in param_cfg.values()] assert set(model.named_vars) == {"trace_prior_", *names} + + +def test_empty(idata, coords): + with pm.Model(coords=coords): + priors = pmx.utils.prior.prior_from_idata(idata) + assert not priors diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 73a3b00b7..962b01bc0 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -132,7 +132,7 @@ def prior_from_idata( idata: arviz.InferenceData, name="trace_prior_", *, - var_names: Sequence[str], + var_names: Sequence[str] = (), **kwargs: Union[ParamCfg, RVTransform, str, Tuple] ) -> Dict[str, pt.TensorVariable]: """ @@ -192,5 +192,7 @@ def prior_from_idata( ... trace1 = pm.sample_prior_predictive(100) """ param_cfg = _parse_args(var_names=var_names, **kwargs) + if not param_cfg: + return {} flat_info = _flatten(idata, **param_cfg) return _mvn_prior_from_flat_info(name, flat_info)