Skip to content

as_model wraps function arguments in Data if they support it. #414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions pymc_extras/model/model_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from functools import wraps
from inspect import signature

from pymc import Model
import pytensor.tensor as pt

from pymc import Data, Model


def as_model(*model_args, **model_kwargs):
Expand All @@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
Additionally, a coords argument is added to the function so coords can be changed during function invocation

All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.

Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.

Examples
Expand Down Expand Up @@ -47,8 +52,19 @@ def decorator(f):
@wraps(f)
def make_model(*args, **kwargs):
coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
sig = signature(f)
ba = sig.bind(*args, **kwargs)
ba.apply_defaults()

with Model(*model_args, coords=coords, **model_kwargs) as m:
f(*args, **kwargs)
for name, v in ba.arguments.items():
# Only wrap pm.Data around values pytensor can process
try:
_ = pt.as_tensor_variable(v)
ba.arguments[name] = Data(name, v)
except (NotImplementedError, TypeError, ValueError):
pass
f(*ba.args, **ba.kwargs)
return m

return make_model
Expand Down
9 changes: 9 additions & 0 deletions tests/model/test_model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,14 @@ def model_wrapped2():

mw2 = model_wrapped2(coords=coords)

@pmx.as_model()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add an explicitly test with numpy array with nan which should fail when wrapped in pm.Data?

def model_wrapped3(mu):
pm.Normal("x", mu, 1.0, dims="obs")

mw3 = model_wrapped3(0.0, coords=coords)
mw4 = model_wrapped3(np.array([np.nan]), coords=coords)

np.testing.assert_equal(model.point_logps(), mw.point_logps())
np.testing.assert_equal(mw.point_logps(), mw2.point_logps())
assert mw3["mu"] in mw3.data_vars
assert "mu" not in mw4
Loading