diff --git a/pymc_extras/model/model_api.py b/pymc_extras/model/model_api.py index aeec8be77..9b84a4012 100644 --- a/pymc_extras/model/model_api.py +++ b/pymc_extras/model/model_api.py @@ -1,6 +1,9 @@ from functools import wraps +from inspect import signature -from pymc import Model +import pytensor.tensor as pt + +from pymc import Data, Model def as_model(*model_args, **model_kwargs): @@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs): This removes all need to think about context managers and lets you separate creating a generative model from using the model. Additionally, a coords argument is added to the function so coords can be changed during function invocation + All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it. + Adapted from `Rob Zinkov's blog post `_ and inspired by the `sampled `_ decorator for PyMC3. Examples @@ -47,8 +52,19 @@ def decorator(f): @wraps(f) def make_model(*args, **kwargs): coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {}) + sig = signature(f) + ba = sig.bind(*args, **kwargs) + ba.apply_defaults() + with Model(*model_args, coords=coords, **model_kwargs) as m: - f(*args, **kwargs) + for name, v in ba.arguments.items(): + # Only wrap pm.Data around values pytensor can process + try: + _ = pt.as_tensor_variable(v) + ba.arguments[name] = Data(name, v) + except (NotImplementedError, TypeError, ValueError): + pass + f(*ba.args, **ba.kwargs) return m return make_model diff --git a/tests/model/test_model_api.py b/tests/model/test_model_api.py index d71d40bad..12e510adb 100644 --- a/tests/model/test_model_api.py +++ b/tests/model/test_model_api.py @@ -25,5 +25,14 @@ def model_wrapped2(): mw2 = model_wrapped2(coords=coords) + @pmx.as_model() + def model_wrapped3(mu): + pm.Normal("x", mu, 1.0, dims="obs") + + mw3 = model_wrapped3(0.0, coords=coords) + mw4 = model_wrapped3(np.array([np.nan]), coords=coords) + np.testing.assert_equal(model.point_logps(), mw.point_logps()) np.testing.assert_equal(mw.point_logps(), mw2.point_logps()) + assert mw3["mu"] in mw3.data_vars + assert "mu" not in mw4