diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 6f48673dc9..3af0bb110e 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -333,6 +333,7 @@ def __new__(
         observed=None,
         total_size=None,
         transform=UNSET,
+        default_transform=UNSET,
         **kwargs,
     ) -> TensorVariable:
         """Adds a tensor variable corresponding to a PyMC distribution to the current model.
@@ -414,10 +415,11 @@ def __new__(
         rv_out = model.register_rv(
             rv_out,
             name,
-            observed,
-            total_size,
+            observed=observed,
+            total_size=total_size,
             dims=dims,
             transform=transform,
+            default_transform=default_transform,
             initval=initval,
         )
diff --git a/pymc/model/core.py b/pymc/model/core.py
index eee8f2a904..2f73c9ee24 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -22,7 +22,6 @@
 from sys import modules
 from typing import (
     TYPE_CHECKING,
-    Any,
     Literal,
     Optional,
     TypeVar,
@@ -48,7 +47,7 @@
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.data import GenTensorVariable, is_minibatch
-from pymc.distributions.transforms import _default_transform
+from pymc.distributions.transforms import ChainedTransform, _default_transform
 from pymc.exceptions import (
     BlockModelAccessError,
     ImputationWarning,
@@ -58,6 +57,7 @@
 )
 from pymc.initial_point import make_initial_point_fn
 from pymc.logprob.basic import transformed_conditional_logp
+from pymc.logprob.transforms import Transform
 from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
 from pymc.model_graph import model_to_graphviz
 from pymc.pytensorf import (
@@ -1214,7 +1214,16 @@ def set_data(
         shared_object.set_value(values)

     def register_rv(
-        self, rv_var, name, observed=None, total_size=None, dims=None, transform=UNSET, initval=None
+        self,
+        rv_var,
+        name,
+        *,
+        observed=None,
+        total_size=None,
+        dims=None,
+        default_transform=UNSET,
+        transform=UNSET,
+        initval=None,
     ):
         """Register an (un)observed random variable with the model.
@@ -1229,8 +1238,10 @@
             upscales logp of variable with ``coef = total_size/var.shape[0]``
         dims : tuple
             Dimension names for the variable.
+        default_transform
+            A default transform for the random variable in log-likelihood space.
         transform
-            A transform for the random variable in log-likelihood space.
+            An additional transform, applied after the default transform.
         initval
             The initial value of the random variable.
@@ -1255,7 +1266,7 @@
             if total_size is not None:
                 raise ValueError("total_size can only be passed to observed RVs")
             self.free_RVs.append(rv_var)
-            self.create_value_var(rv_var, transform)
+            self.create_value_var(rv_var, transform=transform, default_transform=default_transform)
             self.add_named_variable(rv_var, dims)
             self.set_initval(rv_var, initval)
         else:
@@ -1278,7 +1289,9 @@
         # `rv_var` is potentially changed by `make_obs_var`,
         # for example into a new graph for imputation of missing data.
-        rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)
+        rv_var = self.make_obs_var(
+            rv_var, observed, dims, default_transform, transform, total_size
+        )

         return rv_var

@@ -1287,7 +1300,8 @@ def make_obs_var(
         rv_var: TensorVariable,
         data: np.ndarray,
         dims,
-        transform: Any | None,
+        default_transform: Transform | None,
+        transform: Transform | None,
         total_size: int | None,
     ) -> TensorVariable:
         """Create a `TensorVariable` for an observed random variable.
@@ -1301,8 +1315,10 @@
             The observed data.
         dims : tuple
             Dimension names for the variable.
-        transform : int, optional
+        default_transform
             A transform for the random variable in log-likelihood space.
+        transform
+            An additional transform, applied after the default transform.

         Returns
         -------
@@ -1339,12 +1355,19 @@
                 # Register ObservedRV corresponding to observed component
                 observed_rv.name = f"{name}_observed"
-                self.create_value_var(observed_rv, transform=None, value_var=observed_data)
+                self.create_value_var(
+                    observed_rv, transform=transform, default_transform=None, value_var=observed_data
+                )
                 self.add_named_variable(observed_rv)
                 self.observed_RVs.append(observed_rv)

                 # Register FreeRV corresponding to unobserved components
-                self.register_rv(unobserved_rv, f"{name}_unobserved", transform=transform)
+                self.register_rv(
+                    unobserved_rv,
+                    f"{name}_unobserved",
+                    transform=transform,
+                    default_transform=default_transform,
+                )

                 # Register Deterministic that combines observed and missing
                 # Note: This can widely increase memory consumption during sampling for large datasets
@@ -1363,14 +1386,21 @@
             rv_var.name = name

         rv_var.tag.observations = data
-        self.create_value_var(rv_var, transform=None, value_var=data)
+        self.create_value_var(
+            rv_var, transform=transform, default_transform=None, value_var=data
+        )
         self.add_named_variable(rv_var, dims)
         self.observed_RVs.append(rv_var)

         return rv_var

     def create_value_var(
-        self, rv_var: TensorVariable, transform: Any, value_var: Variable | None = None
+        self,
+        rv_var: TensorVariable,
+        *,
+        default_transform: Transform,
+        transform: Transform,
+        value_var: Variable | None = None,
     ) -> TensorVariable:
         """Create a ``TensorVariable`` that will be used as the random variable's "value" in log-likelihood graphs.
@@ -1385,7 +1415,11 @@
         ----------
         rv_var : TensorVariable

-        transform : Any
+        default_transform : Transform
+            A default transform for the random variable in log-likelihood space.
+
+        transform : Transform
+            An additional transform, applied after the default transform.

         value_var : Variable, optional
@@ -1396,11 +1430,25 @@
         # Make the value variable a transformed value variable,
         # if there's an applicable transform
-        if transform is UNSET:
+        if transform is None and default_transform is UNSET:
+            default_transform = None
+            warnings.warn(
+                "To disable the default transform, please use default_transform=None"
+                " instead of transform=None. Setting transform to None will"
+                " not have any effect in the future.",
+                UserWarning,
+            )
+
+        if default_transform is UNSET:
             if rv_var.owner is None:
-                transform = None
+                default_transform = None
             else:
-                transform = _default_transform(rv_var.owner.op, rv_var)
+                default_transform = _default_transform(rv_var.owner.op, rv_var)
+
+        if transform is UNSET:
+            transform = default_transform
+        elif transform is not None and default_transform is not None:
+            transform = ChainedTransform([default_transform, transform])

         if value_var is None:
             if transform is None:
diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py
index 05d8fe4200..7a3fdd9829 100644
--- a/pymc/model/fgraph.py
+++ b/pymc/model/fgraph.py
@@ -320,12 +320,14 @@ def first_non_model_var(var):
         var, value, *dims = model_var.owner.inputs
         transform = model_var.owner.op.transform
         model.free_RVs.append(var)
-        model.create_value_var(var, transform=transform, value_var=value)
+        model.create_value_var(
+            var, transform=transform, default_transform=None, value_var=value
+        )
         model.set_initval(var, initval=None)
     elif isinstance(model_var.owner.op, ModelObservedRV):
         var, value, *dims = model_var.owner.inputs
         model.observed_RVs.append(var)
-        model.create_value_var(var, transform=None, value_var=value)
+        model.create_value_var(var, transform=None, default_transform=None, value_var=value)
     elif isinstance(model_var.owner.op, ModelPotential):
         var, *dims = model_var.owner.inputs
         model.potentials.append(var)
diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py
index a7293c0ed3..e328c96931 100644
--- a/tests/distributions/test_mixture.py
+++ b/tests/distributions/test_mixture.py
@@ -1359,17 +1359,17 @@ def test_warning(self):
         with warnings.catch_warnings():
             warnings.simplefilter("error")
-            Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)
+            Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, default_transform=None)

         with warnings.catch_warnings():
             warnings.simplefilter("error")
-            Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)
+            Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)

         # Case where the appropriate default transform is None
         comp_dists = [Normal.dist(), Normal.dist()]
         with warnings.catch_warnings():
             warnings.simplefilter("error")
-            Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists)
+            Mixture("mix7", w=[0.5, 0.5], comp_dists=comp_dists)
diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py
index b0187a4ebe..8d464f206a 100644
--- a/tests/distributions/test_transform.py
+++ b/tests/distributions/test_transform.py
@@ -619,7 +619,7 @@ def test_transform_univariate_dist_logp_shape():
 def test_univariate_transform_multivariate_dist_raises():
     with pm.Model() as m:
-        pm.Dirichlet("x", [1, 1, 1], transform=tr.log)
+        pm.Dirichlet("x", [1, 1, 1], default_transform=tr.log)

     for jacobian_val in (True, False):
         with pytest.raises(
@@ -645,7 +645,7 @@ def log_jac_det(self, value, *inputs):
     buggy_transform = BuggyTransform()

     with pm.Model() as m:
-        pm.Uniform("x", shape=(4, 3), transform=buggy_transform)
+        pm.Uniform("x", shape=(4, 3), default_transform=buggy_transform)

     for jacobian_val in (True, False):
         with pytest.raises(
diff --git a/tests/logprob/test_utils.py b/tests/logprob/test_utils.py
index c59b332495..47cb65f195 100644
--- a/tests/logprob/test_utils.py
+++ b/tests/logprob/test_utils.py
@@ -218,11 +218,11 @@ def test_interdependent_transformed_rvs(self, reversed):
             transform = pm.distributions.transforms.Interval(
                 bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
             )
-            x = pm.Uniform("x", lower=0, upper=1, transform=transform)
+            x = pm.Uniform("x", lower=0, upper=1, default_transform=transform)
             # Operation between the variables provides a regression test for #7054
-            y = pm.Uniform("y", lower=0, upper=pt.exp(x), transform=transform)
-            z = pm.Uniform("z", lower=0, upper=y, transform=transform)
-            w = pm.Uniform("w", lower=0, upper=pt.square(z), transform=transform)
+            y = pm.Uniform("y", lower=0, upper=pt.exp(x), default_transform=transform)
+            z = pm.Uniform("z", lower=0, upper=y, default_transform=transform)
+            w = pm.Uniform("w", lower=0, upper=pt.square(z), default_transform=transform)
             rvs = [x, y, z, w]
             if reversed:
diff --git a/tests/model/test_core.py b/tests/model/test_core.py
index d62890ab1d..bb68d07886 100644
--- a/tests/model/test_core.py
+++ b/tests/model/test_core.py
@@ -42,7 +42,14 @@
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.distributions import Normal, transforms
 from pymc.distributions.distribution import PartialObservedRV
-from pymc.distributions.transforms import log, simplex
+from pymc.distributions.transforms import (
+    ChainedTransform,
+    Interval,
+    LogTransform,
+    log,
+    ordered,
+    simplex,
+)
 from pymc.exceptions import ImputationWarning, ShapeError, ShapeWarning
 from pymc.logprob.basic import transformed_conditional_logp
 from pymc.logprob.transforms import IntervalTransform
@@ -527,6 +534,35 @@ def test_model_var_maps():
     assert model.rvs_to_transforms[x] is None


+class TestTransformArgs:
+    def test_transform_warning(self):
+        with pm.Model():
+            with pytest.warns(
+                UserWarning,
+                match="To disable the default transform,"
+                " please use default_transform=None"
+                " instead of transform=None. Setting transform to None will"
+                " not have any effect in the future.",
+            ):
+                a = pm.Normal("a", transform=None)
+
+    def test_transform_order(self):
+        with pm.Model() as model:
+            x = pm.Normal("x", transform=Interval(0, 1), default_transform=log)
+            transform = model.rvs_to_transforms[x]
+            assert isinstance(transform, ChainedTransform)
+            assert isinstance(transform.transform_list[0], LogTransform)
+            assert isinstance(transform.transform_list[1], Interval)
+
+    def test_default_transform_is_applied(self):
+        with pm.Model() as model1:
+            x1 = pm.LogNormal("x1", [0, 0], [1, 1], transform=ordered, default_transform=None)
+        with pm.Model() as model2:
+            x2 = pm.LogNormal("x2", [0, 0], [1, 1], transform=ordered)
+        assert np.isinf(model1.compile_logp()({"x1_ordered__": (-1, -1)}))
+        assert np.isfinite(model2.compile_logp()({"x2_chain__": (-1, -1)}))
+
+
 def test_make_obs_var():
     """
     Check returned values for `data` given known inputs to `as_tensor()`.
@@ -549,18 +585,18 @@ def test_make_obs_var():
     # The function requires data and RV dimensionality to be compatible
     with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."):
-        fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None)
+        fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None, None)

     # Check function behavior using the various inputs
     # dense, sparse: Ensure that the missing values are appropriately set to None
     # masked: a deterministic variable is returned
-    dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None)
+    dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None, None)
     assert dense_output == fake_distribution
     assert isinstance(fake_model.rvs_to_values[dense_output], TensorConstant)
     del fake_model.named_vars[fake_distribution.name]

-    sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None)
+    sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None, None)
     assert sparse_output == fake_distribution
     assert sparse.basic._is_sparse_variable(fake_model.rvs_to_values[sparse_output])
     del fake_model.named_vars[fake_distribution.name]
@@ -568,7 +604,7 @@
     # Here the RandomVariable is split into observed/imputed and a Deterministic is returned
     with pytest.warns(ImputationWarning):
         masked_output = fake_model.make_obs_var(
-            fake_distribution, masked_array_input, None, None, None
+            fake_distribution, masked_array_input, None, None, None, None
         )
     assert masked_output != fake_distribution
     assert not isinstance(masked_output, RandomVariable)
@@ -581,7 +617,7 @@
     # Test that setting total_size returns a MinibatchRandomVariable
     scaled_outputs = fake_model.make_obs_var(
-        fake_distribution, dense_input, None, None, total_size=100
+        fake_distribution, dense_input, None, None, None, total_size=100
     )
     assert scaled_outputs != fake_distribution
     assert isinstance(scaled_outputs.owner.op, MinibatchRandomVariable)
diff --git a/tests/model/transform/test_conditioning.py b/tests/model/transform/test_conditioning.py
index a9a8ab712d..25f4693052 100644
--- a/tests/model/transform/test_conditioning.py
+++ b/tests/model/transform/test_conditioning.py
@@ -286,8 +286,8 @@ def test_change_value_transforms_error():

 def test_remove_value_transforms():
     with pm.Model() as base_m:
-        p = pm.Uniform("p", transform=logodds)
-        q = pm.Uniform("q", transform=logodds)
+        p = pm.Uniform("p", transform=logodds, default_transform=None)
+        q = pm.Uniform("q", transform=logodds, default_transform=None)

     new_m = remove_value_transforms(base_m)
     new_p = new_m["p"]
diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py
index 3f676d0846..31b48250fc 100644
--- a/tests/sampling/test_mcmc.py
+++ b/tests/sampling/test_mcmc.py
@@ -303,7 +303,7 @@ def test_transform_with_rv_dependency(self, symbolic_rv):
             transform = pm.distributions.transforms.Interval(
                 bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
             )
-            y = pm.Uniform("y", lower=0, upper=x, transform=transform)
+            y = pm.Uniform("y", lower=0, upper=x, transform=transform, default_transform=None)
             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
                 trace = pm.sample(tune=10, draws=50, return_inferencedata=False, random_seed=336)
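
Reviewer note: the snippet below is not part of the diff. It is a minimal usage sketch of the semantics introduced above, mirroring test_default_transform_is_applied: `transform` now chains after the distribution's default transform, and `default_transform=None` replaces the deprecated `transform=None` for disabling the default. The model names `chained`/`unchained` are illustrative only; the value-variable names follow the new tests.

import numpy as np
import pymc as pm
from pymc.distributions.transforms import ordered

with pm.Model() as chained:
    # LogNormal's default log transform still applies; `ordered` is chained
    # after it, so the value variable is named "x_chain__" and any ordering
    # of the raw values stays within the distribution's support.
    x = pm.LogNormal("x", mu=[0, 0], sigma=[1, 1], transform=ordered)

with pm.Model() as unchained:
    # default_transform=None disables the log transform, so only `ordered`
    # applies and the value variable is named "y_ordered__"; negative raw
    # values then map outside LogNormal's support.
    y = pm.LogNormal("y", mu=[0, 0], sigma=[1, 1], transform=ordered, default_transform=None)

assert np.isfinite(chained.compile_logp()({"x_chain__": (-1, -1)}))
assert np.isinf(unchained.compile_logp()({"y_ordered__": (-1, -1)}))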