-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Implement default_transform
and transform
argument for distributions
#7207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
8a50f9f
7434b9a
4fde32f
5c7519f
1dc0fcb
55b0edb
d791ac7
def0f6e
7d6ecf9
704aac6
23bd69f
35b44fe
7f0d1d3
80641f8
80fa510
655a364
995ec93
3eefa8e
8a005fe
8a708f7
412ef7f
48f5f85
bff4371
e0f26b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,7 +48,7 @@ | |
|
||
from pymc.blocking import DictToArrayBijection, RaveledVars | ||
from pymc.data import GenTensorVariable, is_minibatch | ||
from pymc.distributions.transforms import _default_transform | ||
from pymc.distributions.transforms import ChainedTransform, _default_transform | ||
from pymc.exceptions import ( | ||
BlockModelAccessError, | ||
ImputationWarning, | ||
|
@@ -1214,7 +1214,16 @@ def set_data( | |
shared_object.set_value(values) | ||
|
||
def register_rv( | ||
self, rv_var, name, observed=None, total_size=None, dims=None, transform=UNSET, initval=None | ||
self, | ||
rv_var, | ||
name, | ||
*, | ||
observed=None, | ||
aerubanov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
total_size=None, | ||
dims=None, | ||
transform=UNSET, | ||
default_transform=UNSET, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing description in the docstrings |
||
initval=None, | ||
): | ||
"""Register an (un)observed random variable with the model. | ||
|
||
|
@@ -1255,7 +1264,7 @@ def register_rv( | |
if total_size is not None: | ||
raise ValueError("total_size can only be passed to observed RVs") | ||
self.free_RVs.append(rv_var) | ||
self.create_value_var(rv_var, transform) | ||
self.create_value_var(rv_var, transform=transform, default_transform=default_transform) | ||
self.add_named_variable(rv_var, dims) | ||
self.set_initval(rv_var, initval) | ||
else: | ||
|
@@ -1278,7 +1287,9 @@ def register_rv( | |
|
||
# `rv_var` is potentially changed by `make_obs_var`, | ||
# for example into a new graph for imputation of missing data. | ||
rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size) | ||
rv_var = self.make_obs_var( | ||
rv_var, observed, dims, transform, default_transform, total_size | ||
) | ||
|
||
return rv_var | ||
|
||
|
@@ -1288,6 +1299,7 @@ def make_obs_var( | |
data: np.ndarray, | ||
dims, | ||
transform: Any | None, | ||
default_transform: Any | None, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing description in the docstrings. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also type hint isn't great, should be |
||
total_size: int | None, | ||
) -> TensorVariable: | ||
"""Create a `TensorVariable` for an observed random variable. | ||
|
@@ -1339,12 +1351,19 @@ def make_obs_var( | |
|
||
# Register ObservedRV corresponding to observed component | ||
observed_rv.name = f"{name}_observed" | ||
self.create_value_var(observed_rv, transform=None, value_var=observed_data) | ||
self.create_value_var( | ||
observed_rv, transform=transform, default_transform=None, value_var=observed_data | ||
) | ||
self.add_named_variable(observed_rv) | ||
self.observed_RVs.append(observed_rv) | ||
|
||
# Register FreeRV corresponding to unobserved components | ||
self.register_rv(unobserved_rv, f"{name}_unobserved", transform=transform) | ||
self.register_rv( | ||
unobserved_rv, | ||
f"{name}_unobserved", | ||
transform=transform, | ||
default_transform=default_transform, | ||
) | ||
|
||
# Register Deterministic that combines observed and missing | ||
# Note: This can widely increase memory consumption during sampling for large datasets | ||
|
@@ -1363,14 +1382,21 @@ def make_obs_var( | |
rv_var.name = name | ||
|
||
rv_var.tag.observations = data | ||
self.create_value_var(rv_var, transform=None, value_var=data) | ||
self.create_value_var( | ||
rv_var, transform=transform, default_transform=None, value_var=data | ||
) | ||
self.add_named_variable(rv_var, dims) | ||
self.observed_RVs.append(rv_var) | ||
|
||
return rv_var | ||
|
||
def create_value_var( | ||
self, rv_var: TensorVariable, transform: Any, value_var: Variable | None = None | ||
self, | ||
rv_var: TensorVariable, | ||
*, | ||
transform: Any, | ||
aerubanov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
default_transform: Any, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update type hints and docstring |
||
value_var: Variable | None = None, | ||
) -> TensorVariable: | ||
"""Create a ``TensorVariable`` that will be used as the random | ||
variable's "value" in log-likelihood graphs. | ||
|
@@ -1396,11 +1422,16 @@ def create_value_var( | |
|
||
# Make the value variable a transformed value variable, | ||
# if there's an applicable transform | ||
if transform is UNSET: | ||
if default_transform is UNSET: | ||
if rv_var.owner is None: | ||
transform = None | ||
default_transform = None | ||
else: | ||
transform = _default_transform(rv_var.owner.op, rv_var) | ||
default_transform = _default_transform(rv_var.owner.op, rv_var) | ||
|
||
if transform is UNSET: | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
transform = default_transform | ||
elif transform and default_transform: | ||
aerubanov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
transform = ChainedTransform([default_transform, transform]) | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if value_var is None: | ||
if transform is None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1359,17 +1359,25 @@ def test_warning(self): | |
|
||
with warnings.catch_warnings(): | ||
warnings.simplefilter("error") | ||
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None) | ||
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, default_transform=None) | ||
|
||
with pytest.warns( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This functionality shouldn't be tested here, since it's not specific to Mixture. Probably in model/test_core.py there should be related stuff already? I see you already have it there, so is this test needed? |
||
UserWarning, | ||
match="To disable default transform, please use default_transform=None" | ||
" instead of transform=None. Setting transform to None will not have" | ||
" any effect in future.", | ||
): | ||
Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, transform=None) | ||
|
||
with warnings.catch_warnings(): | ||
warnings.simplefilter("error") | ||
Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, observed=1) | ||
Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists, observed=1) | ||
|
||
# Case where the appropriate default transform is None | ||
comp_dists = [Normal.dist(), Normal.dist()] | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("error") | ||
Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists) | ||
Mixture("mix7", w=[0.5, 0.5], comp_dists=comp_dists) | ||
|
||
|
||
class TestZeroInflatedMixture: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -218,11 +218,15 @@ def test_interdependent_transformed_rvs(self, reversed): | |
transform = pm.distributions.transforms.Interval( | ||
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1]) | ||
) | ||
x = pm.Uniform("x", lower=0, upper=1, transform=transform) | ||
x = pm.Uniform("x", lower=0, upper=1, transform=transform, default_transform=None) | ||
# Operation between the variables provides a regression test for #7054 | ||
y = pm.Uniform("y", lower=0, upper=pt.exp(x), transform=transform) | ||
z = pm.Uniform("z", lower=0, upper=y, transform=transform) | ||
w = pm.Uniform("w", lower=0, upper=pt.square(z), transform=transform) | ||
y = pm.Uniform( | ||
"y", lower=0, upper=pt.exp(x), transform=transform, default_transform=None | ||
) | ||
z = pm.Uniform("z", lower=0, upper=y, transform=transform, default_transform=None) | ||
w = pm.Uniform( | ||
"w", lower=0, upper=pt.square(z), transform=transform, default_transform=None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pass transform to |
||
) | ||
|
||
rvs = [x, y, z, w] | ||
if reversed: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,7 +42,7 @@ | |
from pymc.blocking import DictToArrayBijection, RaveledVars | ||
from pymc.distributions import Normal, transforms | ||
from pymc.distributions.distribution import PartialObservedRV | ||
from pymc.distributions.transforms import log, simplex | ||
from pymc.distributions.transforms import Transform, log, simplex | ||
from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning | ||
from pymc.logprob.basic import transformed_conditional_logp | ||
from pymc.logprob.transforms import IntervalTransform | ||
|
@@ -527,6 +527,42 @@ def test_model_var_maps(): | |
assert model.rvs_to_transforms[x] is None | ||
|
||
|
||
class TestTransformArgs: | ||
def test_transform_warning(self): | ||
with pm.Model(): | ||
with pytest.warns( | ||
UserWarning, | ||
match="To disable default transform," | ||
" please use default_transform=None" | ||
" instead of transform=None. Setting transform to" | ||
" None will not have any effect in future.", | ||
): | ||
a = pm.Normal("a", transform=None) | ||
|
||
def test_transform_order(self): | ||
transform_order = [] | ||
|
||
class DummyTransform(Transform): | ||
name = "dummy1" | ||
ndim_supp = 0 | ||
|
||
def __init__(self, marker) -> None: | ||
super().__init__() | ||
self.marker = marker | ||
|
||
def forward(self, value, *inputs): | ||
nonlocal transform_order | ||
transform_order.append(self.marker) | ||
return value | ||
|
||
def backward(self, value, *inputs): | ||
return value | ||
|
||
with pm.Model() as model: | ||
x = pm.Normal("x", transform=DummyTransform(2), default_transform=DummyTransform(1)) | ||
assert transform_order == [1, 2] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You can use regular transforms, and simply assert the obtained transform is a Chain, which has a property like Also, would be nice to include a numerical example that would have led to |
||
|
||
|
||
def test_make_obs_var(): | ||
""" | ||
Check returned values for `data` given known inputs to `as_tensor()`. | ||
|
@@ -549,26 +585,26 @@ def test_make_obs_var(): | |
|
||
# The function requires data and RV dimensionality to be compatible | ||
with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."): | ||
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None) | ||
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None, None) | ||
|
||
# Check function behavior using the various inputs | ||
# dense, sparse: Ensure that the missing values are appropriately set to None | ||
# masked: a deterministic variable is returned | ||
|
||
dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None) | ||
dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None, None) | ||
assert dense_output == fake_distribution | ||
assert isinstance(fake_model.rvs_to_values[dense_output], TensorConstant) | ||
del fake_model.named_vars[fake_distribution.name] | ||
|
||
sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None) | ||
sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None, None) | ||
assert sparse_output == fake_distribution | ||
assert sparse.basic._is_sparse_variable(fake_model.rvs_to_values[sparse_output]) | ||
del fake_model.named_vars[fake_distribution.name] | ||
|
||
# Here the RandomVariable is split into observed/imputed and a Deterministic is returned | ||
with pytest.warns(ImputationWarning): | ||
masked_output = fake_model.make_obs_var( | ||
fake_distribution, masked_array_input, None, None, None | ||
fake_distribution, masked_array_input, None, None, None, None | ||
) | ||
assert masked_output != fake_distribution | ||
assert not isinstance(masked_output, RandomVariable) | ||
|
@@ -581,7 +617,7 @@ def test_make_obs_var(): | |
|
||
# Test that setting total_size returns a MinibatchRandomVariable | ||
scaled_outputs = fake_model.make_obs_var( | ||
fake_distribution, dense_input, None, None, total_size=100 | ||
fake_distribution, dense_input, None, None, None, total_size=100 | ||
) | ||
assert scaled_outputs != fake_distribution | ||
assert isinstance(scaled_outputs.owner.op, MinibatchRandomVariable) | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This warning should be in the relevant section in
pm.Model
instead of Distribution