
Deprecate LoosePointFunc and make FastPointFunc the default #5318


Closed
ricardoV94 opened this issue Jan 7, 2022 Discussed in #5237 · 0 comments · Fixed by #5320


ricardoV94 commented Jan 7, 2022

Discussed in #5237

Originally posted by ricardoV94 December 3, 2021
It seems that the difference is that LoosePointFunc allows you to evaluate a logp, dlogp, or what have you with args and kwargs instead of just a dictionary, but at the cost of being slower, because it has to wrap them in a dictionary (Point) before calling the underlying function.

import pymc as pm

with pm.Model() as m:
    x = pm.Normal('x')
    y = pm.Normal('y', 0, 1, observed=0)

# Uses `LoosePointFunc` under the hood
m.logp({'x': 0})
m.logp(x=0)
m.logp([('x', 0)])  # Couldn't think of any other way of using `args`

# Uses `FastPointFunc` under the hood
m.fastlogp({'x': 0})
m.fastlogp(x=0) # Fails
m.fastlogp([('x', 0)]) # Fails
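To see why the last two fastlogp calls fail, here is a minimal pymc-free sketch. The FastPointFunc below is a simplified copy of the class quoted from model.py further down, and logp is a hypothetical stand-in for a compiled logp function, not pymc's actual API.

```python
class FastPointFunc:
    """Simplified copy of FastPointFunc: forwards a single dict as keyword args."""

    def __init__(self, f):
        self.f = f

    def __call__(self, state):
        return self.f(**state)


def logp(x):
    # Hypothetical stand-in for a compiled logp (unnormalized standard normal)
    return -0.5 * x**2


fast = FastPointFunc(logp)
print(fast({"x": 0.0}))  # works: the dict is unpacked into logp(x=0.0)

try:
    fast(x=0.0)  # __call__ declares no parameter named `x`
except TypeError as e:
    print("kwargs fail:", e)

try:
    fast([("x", 0.0)])  # `**state` requires a mapping, not a list
except TypeError as e:
    print("list of pairs fails:", e)
```

The keyword call fails because __call__ only declares a single state parameter, and the list of pairs fails because ** unpacking needs a mapping. Neither input is normalized, which is exactly what makes this wrapper cheap.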

pymc/pymc/model.py

Lines 1722 to 1766 in f6f1a8e

def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
    """Build a point. Uses same args as dict() does.
    Filters out variables not in the model. All keys are strings.
    Parameters
    ----------
    args, kwargs
        arguments to build a dict
    filter_model_vars : bool
        If `True`, only model variables are included in the result.
    """
    model = modelcontext(kwargs.pop("model", None))
    args = list(args)
    try:
        d = dict(*args, **kwargs)
    except Exception as e:
        raise TypeError(f"can't turn {args} and {kwargs} into a dict. {e}")
    return {
        get_var_name(k): np.array(v)
        for k, v in d.items()
        if not filter_model_vars or (get_var_name(k) in map(get_var_name, model.value_vars))
    }


class FastPointFunc:
    """Wraps so a function so it takes a dict of arguments instead of arguments."""

    def __init__(self, f):
        self.f = f

    def __call__(self, state):
        return self.f(**state)


class LoosePointFunc:
    """Wraps so a function so it takes a dict of arguments instead of arguments
    but can still take arguments."""

    def __init__(self, f, model):
        self.f = f
        self.model = model

    def __call__(self, *args, **kwargs):
        point = Point(model=self.model, *args, filter_model_vars=True, **kwargs)
        return self.f(**point)
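A simplified, dependency-free sketch of the Point / LoosePointFunc path might look like the following. MODEL_VARS is a hypothetical stand-in for the names in model.value_vars, the model-context lookup is dropped, and the np.array conversion the real Point performs is omitted. It shows that all three calling styles end up in the same place, and that the convenience comes from building and filtering a fresh dict on every call, which is the overhead FastPointFunc avoids.

```python
from typing import Any, Callable, Dict

# Hypothetical stand-in for the names in model.value_vars
MODEL_VARS = {"x"}


def point(*args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Simplified Point(): accept dict()-style input, keep only model variables."""
    try:
        d = dict(*args, **kwargs)
    except Exception as e:
        raise TypeError(f"can't turn {args} and {kwargs} into a dict. {e}")
    return {k: v for k, v in d.items() if k in MODEL_VARS}


class LoosePointFunc:
    """Simplified LoosePointFunc: normalizes its input into a dict on every call."""

    def __init__(self, f: Callable[..., Any]):
        self.f = f

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # A new dict is built and filtered here each time; this is the extra cost
        return self.f(**point(*args, **kwargs))


def logp(x):
    # Hypothetical stand-in for a compiled logp function
    return -0.5 * x**2


loose = LoosePointFunc(logp)
from_dict = loose({"x": 1.0})
from_kwargs = loose(x=1.0)
from_pairs = loose([("x", 1.0)])
filtered = loose({"x": 1.0, "y": 5.0})  # "y" is not a model variable, so it is dropped
```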

Since both attributes do the costly job of compiling a new logp function, we could also deprecate the fastlogp name and rename it to something like compile_logp, similar to compute_initial_point. That name gives a better intuition that you should probably cache its output rather than, for instance, call it in a loop.
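As a rough illustration of why the name matters, the sketch below uses a hypothetical compile_logp stand-in that only counts how often it is called; it does no real compilation and is not the pymc function. Calling it inside a loop pays the compilation cost on every iteration, while calling it once and reusing the returned function pays it only once.

```python
compile_count = 0


def compile_logp():
    """Hypothetical stand-in for an expensive compile step (it only counts calls)."""
    global compile_count
    compile_count += 1

    def logp(point):
        # Unnormalized standard normal log-density of x
        return -0.5 * point["x"] ** 2

    return logp


# Anti-pattern: a fresh function is "compiled" on every iteration
for x in range(3):
    compile_logp()({"x": x})

# Preferred: compile once, then reuse the returned function
logp = compile_logp()
values = [logp({"x": x}) for x in range(3)]
```

With a real compiled function the difference is the compile time multiplied by the number of iterations, which is why a compile_* name that signals "cache me" seems preferable to an attribute that looks cheap to access.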

ricardoV94 changed the title from "Deprecate LoosePointFunc and make FastPointFunc the default." to "Deprecate LoosePointFunc and make FastPointFunc the default" on Jan 7, 2022
ricardoV94 self-assigned this on Jan 7, 2022