diff --git a/pymc_extras/model/model_api.py b/pymc_extras/model/model_api.py
index aeec8be77..9b84a4012 100644
--- a/pymc_extras/model/model_api.py
+++ b/pymc_extras/model/model_api.py
@@ -1,6 +1,9 @@
from functools import wraps
+from inspect import signature
-from pymc import Model
+import pytensor.tensor as pt
+
+from pymc import Data, Model
def as_model(*model_args, **model_kwargs):
@@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
Additionally, a coords argument is added to the function so coords can be changed during function invocation
+ All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.
+
Adapted from `Rob Zinkov's blog post `_ and inspired by the `sampled `_ decorator for PyMC3.
Examples
@@ -47,8 +52,19 @@ def decorator(f):
@wraps(f)
def make_model(*args, **kwargs):
coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
+ sig = signature(f)
+ ba = sig.bind(*args, **kwargs)
+ ba.apply_defaults()
+
with Model(*model_args, coords=coords, **model_kwargs) as m:
- f(*args, **kwargs)
+ for name, v in ba.arguments.items():
+ # Only wrap pm.Data around values pytensor can process
+ try:
+ _ = pt.as_tensor_variable(v)
+ ba.arguments[name] = Data(name, v)
+ except (NotImplementedError, TypeError, ValueError):
+ pass
+ f(*ba.args, **ba.kwargs)
return m
return make_model
diff --git a/tests/model/test_model_api.py b/tests/model/test_model_api.py
index d71d40bad..12e510adb 100644
--- a/tests/model/test_model_api.py
+++ b/tests/model/test_model_api.py
@@ -25,5 +25,14 @@ def model_wrapped2():
mw2 = model_wrapped2(coords=coords)
+ @pmx.as_model()
+ def model_wrapped3(mu):
+ pm.Normal("x", mu, 1.0, dims="obs")
+
+ mw3 = model_wrapped3(0.0, coords=coords)
+ mw4 = model_wrapped3(np.array([np.nan]), coords=coords)
+
np.testing.assert_equal(model.point_logps(), mw.point_logps())
np.testing.assert_equal(mw.point_logps(), mw2.point_logps())
+ assert mw3["mu"] in mw3.data_vars
+ assert "mu" not in mw4