From db47146c9ae284db8457574f68fdbef7c676a02f Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Thu, 14 Dec 2023 03:19:18 +0100 Subject: [PATCH] Pass coords argument into model factory --- pymc_experimental/model/model_api.py | 12 +++++++++++- pymc_experimental/tests/model/test_model_api.py | 7 +++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/model/model_api.py b/pymc_experimental/model/model_api.py index 8425a764..29dbc938 100644 --- a/pymc_experimental/model/model_api.py +++ b/pymc_experimental/model/model_api.py @@ -7,6 +7,7 @@ def as_model(*model_args, **model_kwargs): R""" Decorator to provide context to PyMC models declared in a function. This removes all need to think about context managers and lets you separate creating a generative model from using the model. + Additionally, a coords argument is added to the function so coords can be changed during function invocation Adapted from `Rob Zinkov's blog post `_ and inspired by the `sampled `_ decorator for PyMC3. @@ -32,12 +33,21 @@ def basic_model(): m = basic_model() pm.sample(model=m) + # alternative way to use functional API + @pmx.as_model() + def basic_model(): + pm.Normal("x", 0., 1., dims="obs") + + m = basic_model(coords={"obs": ["a", "b"]}) + pm.sample(model=m) + """ def decorator(f): @wraps(f) def make_model(*args, **kwargs): - with Model(*model_args, **model_kwargs) as m: + coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {}) + with Model(*model_args, coords=coords, **model_kwargs) as m: f(*args, **kwargs) return m diff --git a/pymc_experimental/tests/model/test_model_api.py b/pymc_experimental/tests/model/test_model_api.py index a36c5e5b..47a807fe 100644 --- a/pymc_experimental/tests/model/test_model_api.py +++ b/pymc_experimental/tests/model/test_model_api.py @@ -19,4 +19,11 @@ def model_wrapped(): mw = model_wrapped() + @pmx.as_model() + def model_wrapped2(): + pm.Normal("x", 0.0, 1.0, dims="obs") + + mw2 = model_wrapped2(coords=coords) + np.testing.assert_equal(model.point_logps(), mw.point_logps()) + np.testing.assert_equal(mw.point_logps(), mw2.point_logps())