Skip to content

Implement default_transform and transform argument for distributions #7207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def __new__(
observed=None,
total_size=None,
transform=UNSET,
default_transform=UNSET,
**kwargs,
) -> TensorVariable:
"""Adds a tensor variable corresponding to a PyMC distribution to the current model.
Expand Down Expand Up @@ -397,6 +398,15 @@ def __new__(
if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

if transform is None and default_transform is UNSET:
Copy link
Member

@ricardoV94 ricardoV94 Apr 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warning should be in the relevant section in pm.Model instead of Distribution

default_transform = None
warnings.warn(
"To disable default transform, please use default_transform=None"
" instead of transform=None. Setting transform to None will"
" not have any effect in future.",
UserWarning,
)

dims = convert_dims(dims)
if observed is not None:
observed = convert_observed_data(observed)
Expand All @@ -414,10 +424,11 @@ def __new__(
rv_out = model.register_rv(
rv_out,
name,
observed,
total_size,
observed=observed,
total_size=total_size,
dims=dims,
transform=transform,
default_transform=default_transform,
initval=initval,
)

Expand Down
53 changes: 42 additions & 11 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.data import GenTensorVariable, is_minibatch
from pymc.distributions.transforms import _default_transform
from pymc.distributions.transforms import ChainedTransform, _default_transform
from pymc.exceptions import (
BlockModelAccessError,
ImputationWarning,
Expand Down Expand Up @@ -1214,7 +1214,16 @@ def set_data(
shared_object.set_value(values)

def register_rv(
self, rv_var, name, observed=None, total_size=None, dims=None, transform=UNSET, initval=None
self,
rv_var,
name,
*,
observed=None,
total_size=None,
dims=None,
transform=UNSET,
default_transform=UNSET,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing description in the docstrings

initval=None,
):
"""Register an (un)observed random variable with the model.

Expand Down Expand Up @@ -1255,7 +1264,7 @@ def register_rv(
if total_size is not None:
raise ValueError("total_size can only be passed to observed RVs")
self.free_RVs.append(rv_var)
self.create_value_var(rv_var, transform)
self.create_value_var(rv_var, transform=transform, default_transform=default_transform)
self.add_named_variable(rv_var, dims)
self.set_initval(rv_var, initval)
else:
Expand All @@ -1278,7 +1287,9 @@ def register_rv(

# `rv_var` is potentially changed by `make_obs_var`,
# for example into a new graph for imputation of missing data.
rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)
rv_var = self.make_obs_var(
rv_var, observed, dims, transform, default_transform, total_size
)

return rv_var

Expand All @@ -1288,6 +1299,7 @@ def make_obs_var(
data: np.ndarray,
dims,
transform: Any | None,
default_transform: Any | None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing description in the docstrings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also type hint isn't great, should be RVTransform | None (or something like that, don't remember the exact class now)

total_size: int | None,
) -> TensorVariable:
"""Create a `TensorVariable` for an observed random variable.
Expand Down Expand Up @@ -1339,12 +1351,19 @@ def make_obs_var(

# Register ObservedRV corresponding to observed component
observed_rv.name = f"{name}_observed"
self.create_value_var(observed_rv, transform=None, value_var=observed_data)
self.create_value_var(
observed_rv, transform=transform, default_transform=None, value_var=observed_data
)
self.add_named_variable(observed_rv)
self.observed_RVs.append(observed_rv)

# Register FreeRV corresponding to unobserved components
self.register_rv(unobserved_rv, f"{name}_unobserved", transform=transform)
self.register_rv(
unobserved_rv,
f"{name}_unobserved",
transform=transform,
default_transform=default_transform,
)

# Register Deterministic that combines observed and missing
# Note: This can widely increase memory consumption during sampling for large datasets
Expand All @@ -1363,14 +1382,21 @@ def make_obs_var(
rv_var.name = name

rv_var.tag.observations = data
self.create_value_var(rv_var, transform=None, value_var=data)
self.create_value_var(
rv_var, transform=transform, default_transform=None, value_var=data
)
self.add_named_variable(rv_var, dims)
self.observed_RVs.append(rv_var)

return rv_var

def create_value_var(
self, rv_var: TensorVariable, transform: Any, value_var: Variable | None = None
self,
rv_var: TensorVariable,
*,
transform: Any,
default_transform: Any,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update type hints and docstring

value_var: Variable | None = None,
) -> TensorVariable:
"""Create a ``TensorVariable`` that will be used as the random
variable's "value" in log-likelihood graphs.
Expand All @@ -1396,11 +1422,16 @@ def create_value_var(

# Make the value variable a transformed value variable,
# if there's an applicable transform
if transform is UNSET:
if default_transform is UNSET:
if rv_var.owner is None:
transform = None
default_transform = None
else:
transform = _default_transform(rv_var.owner.op, rv_var)
default_transform = _default_transform(rv_var.owner.op, rv_var)

if transform is UNSET:
transform = default_transform
elif transform and default_transform:
transform = ChainedTransform([default_transform, transform])

if value_var is None:
if transform is None:
Expand Down
6 changes: 4 additions & 2 deletions pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,14 @@ def first_non_model_var(var):
var, value, *dims = model_var.owner.inputs
transform = model_var.owner.op.transform
model.free_RVs.append(var)
model.create_value_var(var, transform=transform, value_var=value)
model.create_value_var(
var, transform=transform, default_transform=None, value_var=value
)
model.set_initval(var, initval=None)
elif isinstance(model_var.owner.op, ModelObservedRV):
var, value, *dims = model_var.owner.inputs
model.observed_RVs.append(var)
model.create_value_var(var, transform=None, value_var=value)
model.create_value_var(var, transform=None, default_transform=None, value_var=value)
elif isinstance(model_var.owner.op, ModelPotential):
var, *dims = model_var.owner.inputs
model.potentials.append(var)
Expand Down
14 changes: 11 additions & 3 deletions tests/distributions/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,17 +1359,25 @@ def test_warning(self):

with warnings.catch_warnings():
warnings.simplefilter("error")
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, default_transform=None)

with pytest.warns(
Copy link
Member

@ricardoV94 ricardoV94 Apr 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This functionality shouldn't be tested here, since it's not specific to Mixture. Probably in model/test_core.py there should be related stuff already? I see you already have it there, so is this test needed?

UserWarning,
match="To disable default transform, please use default_transform=None"
" instead of transform=None. Setting transform to None will not have"
" any effect in future.",
):
Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)

with warnings.catch_warnings():
warnings.simplefilter("error")
Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)
Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)

# Case where the appropriate default transform is None
comp_dists = [Normal.dist(), Normal.dist()]
with warnings.catch_warnings():
warnings.simplefilter("error")
Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists)
Mixture("mix7", w=[0.5, 0.5], comp_dists=comp_dists)


class TestZeroInflatedMixture:
Expand Down
4 changes: 2 additions & 2 deletions tests/distributions/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def test_transform_univariate_dist_logp_shape():

def test_univariate_transform_multivariate_dist_raises():
with pm.Model() as m:
pm.Dirichlet("x", [1, 1, 1], transform=tr.log)
pm.Dirichlet("x", [1, 1, 1], transform=tr.log, default_transform=None)

for jacobian_val in (True, False):
with pytest.raises(
Expand All @@ -645,7 +645,7 @@ def log_jac_det(self, value, *inputs):
buggy_transform = BuggyTransform()

with pm.Model() as m:
pm.Uniform("x", shape=(4, 3), transform=buggy_transform)
pm.Uniform("x", shape=(4, 3), transform=buggy_transform, default_transform=None)

for jacobian_val in (True, False):
with pytest.raises(
Expand Down
12 changes: 8 additions & 4 deletions tests/logprob/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,15 @@ def test_interdependent_transformed_rvs(self, reversed):
transform = pm.distributions.transforms.Interval(
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
)
x = pm.Uniform("x", lower=0, upper=1, transform=transform)
x = pm.Uniform("x", lower=0, upper=1, transform=transform, default_transform=None)
# Operation between the variables provides a regression test for #7054
y = pm.Uniform("y", lower=0, upper=pt.exp(x), transform=transform)
z = pm.Uniform("z", lower=0, upper=y, transform=transform)
w = pm.Uniform("w", lower=0, upper=pt.square(z), transform=transform)
y = pm.Uniform(
"y", lower=0, upper=pt.exp(x), transform=transform, default_transform=None
)
z = pm.Uniform("z", lower=0, upper=y, transform=transform, default_transform=None)
w = pm.Uniform(
"w", lower=0, upper=pt.square(z), transform=transform, default_transform=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pass transform to default_transform

)

rvs = [x, y, z, w]
if reversed:
Expand Down
48 changes: 42 additions & 6 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.distributions import Normal, transforms
from pymc.distributions.distribution import PartialObservedRV
from pymc.distributions.transforms import log, simplex
from pymc.distributions.transforms import Transform, log, simplex
from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
from pymc.logprob.basic import transformed_conditional_logp
from pymc.logprob.transforms import IntervalTransform
Expand Down Expand Up @@ -527,6 +527,42 @@ def test_model_var_maps():
assert model.rvs_to_transforms[x] is None


class TestTransformArgs:
    """Tests for the interaction of the ``transform`` and ``default_transform`` arguments."""

    def test_transform_warning(self):
        """Passing ``transform=None`` without ``default_transform`` emits a ``UserWarning``."""
        with pm.Model():
            with pytest.warns(
                UserWarning,
                match="To disable default transform,"
                " please use default_transform=None"
                " instead of transform=None. Setting transform to"
                " None will not have any effect in future.",
            ):
                pm.Normal("a", transform=None)

    def test_transform_order(self):
        """The default transform runs first and the user-supplied transform second."""
        transform_order = []

        class DummyTransform(Transform):
            """Identity transform that records when its ``forward`` is invoked."""

            name = "dummy1"
            ndim_supp = 0

            def __init__(self, marker) -> None:
                super().__init__()
                self.marker = marker

            def forward(self, value, *inputs):
                # The list is mutated in place (never rebound), so no
                # ``nonlocal`` declaration is required here.
                transform_order.append(self.marker)
                return value

            def backward(self, value, *inputs):
                return value

        default = DummyTransform(1)
        user = DummyTransform(2)

        with pm.Model() as model:
            x = pm.Normal("x", transform=user, default_transform=default)

        # Order in which the forward passes were invoked while registering ``x``.
        assert transform_order == [1, 2]

        # The model should have registered a chain holding both transforms,
        # default first, rather than either transform on its own.
        chained = model.rvs_to_transforms[x]
        assert isinstance(chained, transforms.ChainedTransform)
        assert [t.marker for t in chained.transform_list] == [1, 2]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use regular transforms and simply assert that the obtained transform is a chained transform, which has a property like transform_list that includes the transforms; you can then assert those are also the expected ones. The transform is available in model.rvs_to_transforms[x]

Also, would be nice to include a numerical example that would have led to nan or -inf probability before the change, like an ordered mixture of LogNormals evaluated at -1



def test_make_obs_var():
"""
Check returned values for `data` given known inputs to `as_tensor()`.
Expand All @@ -549,26 +585,26 @@ def test_make_obs_var():

# The function requires data and RV dimensionality to be compatible
with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."):
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None)
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None, None)

# Check function behavior using the various inputs
# dense, sparse: Ensure that the missing values are appropriately set to None
# masked: a deterministic variable is returned

dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None)
dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None, None)
assert dense_output == fake_distribution
assert isinstance(fake_model.rvs_to_values[dense_output], TensorConstant)
del fake_model.named_vars[fake_distribution.name]

sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None)
sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None, None)
assert sparse_output == fake_distribution
assert sparse.basic._is_sparse_variable(fake_model.rvs_to_values[sparse_output])
del fake_model.named_vars[fake_distribution.name]

# Here the RandomVariable is split into observed/imputed and a Deterministic is returned
with pytest.warns(ImputationWarning):
masked_output = fake_model.make_obs_var(
fake_distribution, masked_array_input, None, None, None
fake_distribution, masked_array_input, None, None, None, None
)
assert masked_output != fake_distribution
assert not isinstance(masked_output, RandomVariable)
Expand All @@ -581,7 +617,7 @@ def test_make_obs_var():

# Test that setting total_size returns a MinibatchRandomVariable
scaled_outputs = fake_model.make_obs_var(
fake_distribution, dense_input, None, None, total_size=100
fake_distribution, dense_input, None, None, None, total_size=100
)
assert scaled_outputs != fake_distribution
assert isinstance(scaled_outputs.owner.op, MinibatchRandomVariable)
Expand Down
4 changes: 2 additions & 2 deletions tests/model/transform/test_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ def test_change_value_transforms_error():

def test_remove_value_transforms():
with pm.Model() as base_m:
p = pm.Uniform("p", transform=logodds)
q = pm.Uniform("q", transform=logodds)
p = pm.Uniform("p", transform=logodds, default_transform=None)
q = pm.Uniform("q", transform=logodds, default_transform=None)

new_m = remove_value_transforms(base_m)
new_p = new_m["p"]
Expand Down
2 changes: 1 addition & 1 deletion tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def test_transform_with_rv_dependency(self, symbolic_rv):
transform = pm.distributions.transforms.Interval(
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
)
y = pm.Uniform("y", lower=0, upper=x, transform=transform)
y = pm.Uniform("y", lower=0, upper=x, transform=transform, default_transform=None)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
trace = pm.sample(tune=10, draws=50, return_inferencedata=False, random_seed=336)
Expand Down