Skip to content

Commit 8046695

Browse files
theorashidzaxtax
andauthored
Add @as_model decorator (#268)
Co-authored-by: Rob Zinkov <[email protected]>
1 parent 5fc0463 commit 8046695

File tree

8 files changed

+73
-3
lines changed

8 files changed

+73
-3
lines changed

docs/api_reference.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ methods in the current release of PyMC experimental.
88
.. autosummary::
99
:toctree: generated/
1010

11-
marginal_model.MarginalModel
11+
as_model
12+
MarginalModel
1213
model_builder.ModelBuilder
1314

1415
Inference

pymc_experimental/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@
2525

2626
from pymc_experimental import distributions, gp, utils
2727
from pymc_experimental.inference.fit import fit
28-
from pymc_experimental.marginal_model import MarginalModel
28+
from pymc_experimental.model.marginal_model import MarginalModel
29+
from pymc_experimental.model.model_api import as_model

pymc_experimental/model/__init__.py

Whitespace-only changes.

pymc_experimental/model/model_api.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from functools import wraps
2+
3+
from pymc import Model
4+
5+
6+
def as_model(*model_args, **model_kwargs):
7+
R"""
8+
Decorator to provide context to PyMC models declared in a function.
9+
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
10+
11+
Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
12+
13+
Examples
14+
--------
15+
.. code:: python
16+
17+
import pymc as pm
18+
import pymc_experimental as pmx
19+
20+
# The following are equivalent
21+
22+
# standard PyMC API with context manager
23+
with pm.Model(coords={"obs": ["a", "b"]}) as model:
24+
x = pm.Normal("x", 0., 1., dims="obs")
25+
pm.sample()
26+
27+
# functional API using decorator
28+
@pmx.as_model(coords={"obs": ["a", "b"]})
29+
def basic_model():
30+
pm.Normal("x", 0., 1., dims="obs")
31+
32+
m = basic_model()
33+
pm.sample(model=m)
34+
35+
"""
36+
37+
def decorator(f):
38+
@wraps(f)
39+
def make_model(*args, **kwargs):
40+
with Model(*model_args, **model_kwargs) as m:
41+
f(*args, **kwargs)
42+
return m
43+
44+
return make_model
45+
46+
return decorator

pymc_experimental/tests/model/__init__.py

Whitespace-only changes.

pymc_experimental/tests/test_marginal_model.py renamed to pymc_experimental/tests/model/test_marginal_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pymc.util import UNSET
1313
from scipy.special import logsumexp
1414

15-
from pymc_experimental.marginal_model import (
15+
from pymc_experimental.model.marginal_model import (
1616
FiniteDiscreteMarginalRV,
1717
MarginalModel,
1818
is_conditional_dependent,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import numpy as np
2+
import pymc as pm
3+
4+
import pymc_experimental as pmx
5+
6+
7+
def test_logp():
8+
"""Compare standard PyMC `with pm.Model()` context API against `pmx.model` decorator
9+
and a functional syntax. Checks whether the kwarg `coords` can be passed.
10+
"""
11+
coords = {"obs": ["a", "b"]}
12+
13+
with pm.Model(coords=coords) as model:
14+
pm.Normal("x", 0.0, 1.0, dims="obs")
15+
16+
@pmx.as_model(coords=coords)
17+
def model_wrapped():
18+
pm.Normal("x", 0.0, 1.0, dims="obs")
19+
20+
mw = model_wrapped()
21+
22+
np.testing.assert_equal(model.point_logps(), mw.point_logps())

0 commit comments

Comments
 (0)