diff --git a/pymc_experimental/model/model_api.py b/pymc_experimental/model/model_api.py
index 8425a764..29dbc938 100644
--- a/pymc_experimental/model/model_api.py
+++ b/pymc_experimental/model/model_api.py
@@ -7,6 +7,7 @@ def as_model(*model_args, **model_kwargs):
R"""
Decorator to provide context to PyMC models declared in a function.
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
+ Additionally, a coords argument is added to the function so coords can be changed during function invocation
Adapted from `Rob Zinkov's blog post `_ and inspired by the `sampled `_ decorator for PyMC3.
@@ -32,12 +33,21 @@ def basic_model():
m = basic_model()
pm.sample(model=m)
+ # alternative way to use functional API
+ @pmx.as_model()
+ def basic_model():
+ pm.Normal("x", 0., 1., dims="obs")
+
+ m = basic_model(coords={"obs": ["a", "b"]})
+ pm.sample(model=m)
+
"""
def decorator(f):
@wraps(f)
def make_model(*args, **kwargs):
- with Model(*model_args, **model_kwargs) as m:
+ coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
+ with Model(*model_args, coords=coords, **model_kwargs) as m:
f(*args, **kwargs)
return m
diff --git a/pymc_experimental/tests/model/test_model_api.py b/pymc_experimental/tests/model/test_model_api.py
index a36c5e5b..47a807fe 100644
--- a/pymc_experimental/tests/model/test_model_api.py
+++ b/pymc_experimental/tests/model/test_model_api.py
@@ -19,4 +19,11 @@ def model_wrapped():
mw = model_wrapped()
+ @pmx.as_model()
+ def model_wrapped2():
+ pm.Normal("x", 0.0, 1.0, dims="obs")
+
+ mw2 = model_wrapped2(coords=coords)
+
np.testing.assert_equal(model.point_logps(), mw.point_logps())
+ np.testing.assert_equal(mw.point_logps(), mw2.point_logps())