diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 2b828d5c6d..fca5f6a984 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -27,16 +27,17 @@ import pymc as pm from pymc.logprob.transforms import ( + ChainedTransform, CircularTransform, IntervalTransform, LogOddsTransform, LogTransform, - RVTransform, SimplexTransform, + Transform, ) __all__ = [ - "RVTransform", + "Transform", "simplex", "logodds", "Interval", @@ -60,6 +61,10 @@ def __getattr__(name): warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning) return sum_to_1 + if name == "RVTransform": + warnings.warn("RVTransform has been renamed to Transform", FutureWarning) + return Transform + raise AttributeError(f"module {__name__} has no attribute {name}") @@ -69,7 +74,7 @@ def _default_transform(op: Op, rv: TensorVariable): return None -class LogExpM1(RVTransform): +class LogExpM1(Transform): name = "log_exp_m1" def backward(self, value, *inputs): @@ -87,7 +92,7 @@ def log_jac_det(self, value, *inputs): return -pt.softplus(-value) -class Ordered(RVTransform): +class Ordered(Transform): name = "ordered" def __init__(self, ndim_supp=None): @@ -110,7 +115,7 @@ def log_jac_det(self, value, *inputs): return pt.sum(value[..., 1:], axis=-1) -class SumTo1(RVTransform): +class SumTo1(Transform): """ Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of values in [0,1] This Transformation operates on the last dimension of the input tensor. @@ -134,7 +139,7 @@ def log_jac_det(self, value, *inputs): return pt.sum(y, axis=-1) -class CholeskyCovPacked(RVTransform): +class CholeskyCovPacked(Transform): """ Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the log scale @@ -162,45 +167,7 @@ def log_jac_det(self, value, *inputs): return pt.sum(value[..., self.diag_idxs], axis=-1) -class Chain(RVTransform): - __slots__ = ("param_extract_fn", "transform_list", "name") - - def __init__(self, transform_list): - self.transform_list = transform_list - self.name = "+".join([transf.name for transf in self.transform_list]) - - def forward(self, value, *inputs): - y = value - for transf in self.transform_list: - # TODO:Needs proper discussion as to what should be - # passed as inputs here - y = transf.forward(y, *inputs) - return y - - def backward(self, value, *inputs): - x = value - for transf in reversed(self.transform_list): - x = transf.backward(x, *inputs) - return x - - def log_jac_det(self, value, *inputs): - y = pt.as_tensor_variable(value) - det_list = [] - ndim0 = y.ndim - for transf in reversed(self.transform_list): - det_ = transf.log_jac_det(y, *inputs) - det_list.append(det_) - y = transf.backward(y, *inputs) - ndim0 = min(ndim0, det_.ndim) - # match the shape of the smallest log_jac_det - det = 0.0 - for det_ in det_list: - if det_.ndim > ndim0: - det += det_.sum(axis=-1) - else: - det += det_ - return det - +Chain = ChainedTransform simplex = SimplexTransform() simplex.__doc__ = """ @@ -297,7 +264,7 @@ def bounds_fn(*rv_inputs): super().__init__(args_fn=bounds_fn) -class ZeroSumTransform(RVTransform): +class ZeroSumTransform(Transform): """ Constrains any random samples to sum to zero along the user-provided ``zerosum_axes``. @@ -314,43 +281,43 @@ class ZeroSumTransform(RVTransform): def __init__(self, zerosum_axes): self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes) + @staticmethod + def extend_axis(array, axis): + n = pm.floatX(array.shape[axis] + 1) + sum_vals = array.sum(axis, keepdims=True) + norm = sum_vals / (pt.sqrt(n) + n) + fill_val = norm - sum_vals / pt.sqrt(n) + + out = pt.concatenate([array, fill_val], axis=axis) + return out - norm + + @staticmethod + def extend_axis_rev(array, axis): + normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] + + n = pm.floatX(array.shape[normalized_axis]) + last = pt.take(array, [-1], axis=normalized_axis) + + sum_vals = -last * pt.sqrt(n) + norm = sum_vals / (pt.sqrt(n) + n) + slice_before = (slice(None, None),) * normalized_axis + + return array[slice_before + (slice(None, -1),)] + norm + def forward(self, value, *rv_inputs): for axis in self.zerosum_axes: - value = extend_axis_rev(value, axis=axis) + value = self.extend_axis_rev(value, axis=axis) return value def backward(self, value, *rv_inputs): for axis in self.zerosum_axes: - value = extend_axis(value, axis=axis) + value = self.extend_axis(value, axis=axis) return value def log_jac_det(self, value, *rv_inputs): return pt.constant(0.0) -def extend_axis(array, axis): - n = pm.floatX(array.shape[axis] + 1) - sum_vals = array.sum(axis, keepdims=True) - norm = sum_vals / (pt.sqrt(n) + n) - fill_val = norm - sum_vals / pt.sqrt(n) - - out = pt.concatenate([array, fill_val], axis=axis) - return out - norm - - -def extend_axis_rev(array, axis): - normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] - - n = pm.floatX(array.shape[normalized_axis]) - last = pt.take(array, [-1], axis=normalized_axis) - - sum_vals = -last * pt.sqrt(n) - norm = sum_vals / (pt.sqrt(n) + n) - slice_before = (slice(None, None),) * normalized_axis - - return array[slice_before + (slice(None, -1),)] + norm - - log_exp_m1 = LogExpM1() log_exp_m1.__doc__ = """ Instantiation of :class:`pymc.distributions.transforms.LogExpM1` diff --git a/pymc/gp/cov.py b/pymc/gp/cov.py index d711d25a4d..268b678505 100644 --- a/pymc/gp/cov.py +++ b/pymc/gp/cov.py @@ -48,6 +48,8 @@ "Kron", ] +from pymc.pytensorf import constant_fold + TensorLike = Union[np.ndarray, TensorVariable] IntSequence = Union[np.ndarray, Sequence[int]] @@ -183,9 +185,6 @@ def n_dims(self) -> int: def _slice(self, X, Xs=None): xdims = X.shape[-1] if isinstance(xdims, Variable): - # Circular dependency - from pymc.pytensorf import constant_fold - [xdims] = constant_fold([xdims]) if self.input_dim != xdims: warnings.warn( diff --git a/pymc/gp/util.py b/pymc/gp/util.py index 0e683a5d38..876f26b8cc 100644 --- a/pymc/gp/util.py +++ b/pymc/gp/util.py @@ -18,13 +18,14 @@ import pytensor.tensor as pt from pytensor.compile import SharedVariable +from pytensor.graph import ancestors from pytensor.tensor.variable import TensorConstant from scipy.cluster.vq import kmeans # Avoid circular dependency when importing modelcontext from pymc.distributions.distribution import Distribution from pymc.model import modelcontext -from pymc.pytensorf import compile_pymc, walk_model +from pymc.pytensorf import compile_pymc _ = Distribution # keep both pylint and black happy @@ -48,7 +49,7 @@ def replace_with_values(vars_needed, replacements=None, model=None): model = modelcontext(model) inputs, input_names = [], [] - for rv in walk_model(vars_needed): + for rv in ancestors(vars_needed): if rv in model.named_vars.values() and not isinstance(rv, SharedVariable): inputs.append(rv) input_names.append(rv.name) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index b4248e7eda..ddcbb01138 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -24,7 +24,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.tensor.variable import TensorVariable -from pymc.logprob.transforms import RVTransform +from pymc.logprob.transforms import Transform from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name @@ -177,7 +177,7 @@ def inner(seed, *args, **kwargs): def make_initial_point_expression( *, free_rvs: Sequence[TensorVariable], - rvs_to_transforms: Dict[TensorVariable, RVTransform], + rvs_to_transforms: Dict[TensorVariable, Transform], initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]], jitter_rvs: Set[TensorVariable] = None, default_strategy: str = "moment", diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 641cf6d0a3..36f7634539 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -64,8 +64,9 @@ ) from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph from pymc.logprob.transform_value import TransformValuesRewrite -from pymc.logprob.transforms import RVTransform -from pymc.logprob.utils import find_rvs_in_graph, rvs_to_value_vars +from pymc.logprob.transforms import Transform +from pymc.logprob.utils import rvs_in_graph +from pymc.pytensorf import replace_vars_in_graphs TensorLike: TypeAlias = Union[Variable, float, np.ndarray] @@ -76,7 +77,7 @@ def _find_unallowed_rvs_in_graph(graph): return { rv - for rv in find_rvs_in_graph(graph) + for rv in rvs_in_graph(graph) if not isinstance(rv.owner.op, (SimulatorRV, MinibatchIndexRV)) } @@ -530,11 +531,9 @@ def conditional_logp( continue # Replace `RandomVariable`s in the inputs with value variables. - # Also, store the results in the `replacements` map for the nodes - # that follow. - remapped_vars, _ = rvs_to_value_vars( - q_values + list(node.inputs), - initial_replacements=replacements, + remapped_vars = replace_vars_in_graphs( + graphs=q_values + list(node.inputs), + replacements=replacements, ) q_values = remapped_vars[: len(q_values)] q_rv_inputs = remapped_vars[len(q_values) :] @@ -562,8 +561,7 @@ def conditional_logp( logprob_vars[q_value_var] = q_logprob_var - # Recompute test values for the changes introduced by the - # replacements above. + # Recompute test values for the changes introduced by the replacements above. if config.compute_test_value != "off": for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars): compute_test_value(node) @@ -589,7 +587,7 @@ def transformed_conditional_logp( rvs: Sequence[TensorVariable], *, rvs_to_values: Dict[TensorVariable, TensorVariable], - rvs_to_transforms: Dict[TensorVariable, RVTransform], + rvs_to_transforms: Dict[TensorVariable, Transform], jacobian: bool = True, **kwargs, ) -> List[TensorVariable]: diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index 1049fd7bb7..fb5e672421 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -44,6 +44,7 @@ from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db +from pymc.logprob.utils import replace_rvs_by_values class MeasurableSpecifyShape(SpecifyShape): @@ -107,8 +108,6 @@ class MeasurableCheckAndRaise(CheckAndRaise): @_logprob.register(MeasurableCheckAndRaise) def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs): - from pymc.pytensorf import replace_rvs_by_values - (value,) = values # transfer assertion from rv to value assertions = replace_rvs_by_values(assertions, rvs_to_values={inner_rv: value}) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index b5c4947937..a77e5531a8 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -78,7 +78,8 @@ measurable_ir_rewrites_db, subtensor_ops, ) -from pymc.logprob.utils import check_potential_measurability +from pymc.logprob.utils import check_potential_measurability, replace_rvs_by_values +from pymc.pytensorf import constant_fold def is_newaxis(x): @@ -255,9 +256,6 @@ def get_stack_mixture_vars( mixture_rvs = joined_rvs.owner.inputs elif isinstance(joined_rvs.owner.op, Join): - # TODO: Find better solution to avoid this circular dependency - from pymc.pytensorf import constant_fold - join_axis = joined_rvs.owner.inputs[0] # TODO: Support symbolic join axes. This will raise ValueError if it's not a constant (join_axis,) = constant_fold((join_axis,), raise_not_constant=False) @@ -351,9 +349,6 @@ def logprob_MixtureRV( comp_rvs = [comp[None] for comp in comp_rvs] original_shape = (len(comp_rvs),) else: - # TODO: Find better solution to avoid this circular dependency - from pymc.pytensorf import constant_fold - join_axis_val = constant_fold((join_axis,))[0].item() original_shape = shape_tuple(comp_rvs[0]) @@ -544,7 +539,6 @@ def find_measurable_ifelse_mixture(fgraph, node): @_logprob.register(MeasurableIfElse) def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs): """Compute the log-likelihood graph for an `IfElse`.""" - from pymc.pytensorf import replace_rvs_by_values assert len(values) * 2 == len(base_rvs) diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 283dbd1c3e..e0b4fe3032 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -62,7 +62,7 @@ logprob_rewrites_db, measurable_ir_rewrites_db, ) -from pymc.pytensorf import replace_rvs_by_values +from pymc.logprob.utils import replace_rvs_by_values class MeasurableScan(Scan): diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index 1223b7091a..e5f6ebc5a9 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -55,7 +55,8 @@ assume_measured_ir_outputs, measurable_ir_rewrites_db, ) -from pymc.logprob.utils import check_potential_measurability +from pymc.logprob.utils import check_potential_measurability, replace_rvs_by_values +from pymc.pytensorf import constant_fold @node_rewriter([Alloc]) @@ -131,7 +132,6 @@ class MeasurableMakeVector(MakeVector): def logprob_make_vector(op, values, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `MeasurableMakeVector`.""" # TODO: Sort out this circular dependency issue - from pymc.pytensorf import replace_rvs_by_values (value,) = values @@ -158,9 +158,6 @@ class MeasurableJoin(Join): @_logprob.register(MeasurableJoin) def logprob_join(op, values, axis, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `Join`.""" - # TODO: Find better way to avoid circular dependency - from pymc.pytensorf import constant_fold, replace_rvs_by_values - (value,) = values base_rv_shapes = [base_var.shape[axis] for base_var in base_rvs] diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py index 4c9b82ebc4..ee37fdca1d 100644 --- a/pymc/logprob/transform_value.py +++ b/pymc/logprob/transform_value.py @@ -27,7 +27,7 @@ from pymc.logprob.abstract import MeasurableVariable, _logprob from pymc.logprob.rewriting import PreserveRVMappings, cleanup_ir_rewrites_db -from pymc.logprob.transforms import RVTransform +from pymc.logprob.transforms import Transform class TransformedValue(Op): @@ -67,7 +67,7 @@ class TransformedValueRV(Op): __props__ = ("transforms",) - def __init__(self, transforms: Sequence[RVTransform]): + def __init__(self, transforms: Sequence[Transform]): self.transforms = tuple(transforms) super().__init__() @@ -320,7 +320,7 @@ class TransformValuesRewrite(GraphRewriter): def __init__( self, - values_to_transforms: Dict[TensorVariable, Union[RVTransform, None]], + values_to_transforms: Dict[TensorVariable, Union[Transform, None]], ): """ Parameters diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 2ddd9d9e5b..ed4af12565 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -120,7 +120,7 @@ from pymc.logprob.utils import CheckParameterValue, check_potential_measurability -class RVTransform(abc.ABC): +class Transform(abc.ABC): ndim_supp = None @abc.abstractmethod @@ -174,10 +174,10 @@ class MeasurableTransform(MeasurableElemwise): # Cannot use `transform` as name because it would clash with the property added by # the `TransformValuesRewrite` - transform_elemwise: RVTransform + transform_elemwise: Transform measurable_input_idx: int - def __init__(self, *args, transform: RVTransform, measurable_input_idx: int, **kwargs): + def __init__(self, *args, transform: Transform, measurable_input_idx: int, **kwargs): self.transform_elemwise = transform self.measurable_input_idx = measurable_input_idx super().__init__(*args, **kwargs) @@ -444,7 +444,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li scalar_op = node.op.scalar_op measurable_input_idx = 0 transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,) - transform: RVTransform + transform: Transform transform_dict = { Exp: ExpTransform(), @@ -559,7 +559,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li ) -class SinhTransform(RVTransform): +class SinhTransform(Transform): name = "sinh" ndim_supp = 0 @@ -570,7 +570,7 @@ def backward(self, value, *inputs): return pt.arcsinh(value) -class CoshTransform(RVTransform): +class CoshTransform(Transform): name = "cosh" ndim_supp = 0 @@ -589,7 +589,7 @@ def log_jac_det(self, value, *inputs): ) -class TanhTransform(RVTransform): +class TanhTransform(Transform): name = "tanh" ndim_supp = 0 @@ -600,7 +600,7 @@ def backward(self, value, *inputs): return pt.arctanh(value) -class ArcsinhTransform(RVTransform): +class ArcsinhTransform(Transform): name = "arcsinh" ndim_supp = 0 @@ -611,7 +611,7 @@ def backward(self, value, *inputs): return pt.sinh(value) -class ArccoshTransform(RVTransform): +class ArccoshTransform(Transform): name = "arccosh" ndim_supp = 0 @@ -622,7 +622,7 @@ def backward(self, value, *inputs): return pt.cosh(value) -class ArctanhTransform(RVTransform): +class ArctanhTransform(Transform): name = "arctanh" ndim_supp = 0 @@ -633,7 +633,7 @@ def backward(self, value, *inputs): return pt.tanh(value) -class ErfTransform(RVTransform): +class ErfTransform(Transform): name = "erf" ndim_supp = 0 @@ -644,7 +644,7 @@ def backward(self, value, *inputs): return pt.erfinv(value) -class ErfcTransform(RVTransform): +class ErfcTransform(Transform): name = "erfc" ndim_supp = 0 @@ -655,7 +655,7 @@ def backward(self, value, *inputs): return pt.erfcinv(value) -class ErfcxTransform(RVTransform): +class ErfcxTransform(Transform): name = "erfcx" ndim_supp = 0 @@ -681,7 +681,7 @@ def calc_delta_x(value, prior_result): return result[-1] -class LocTransform(RVTransform): +class LocTransform(Transform): name = "loc" def __init__(self, transform_args_fn): @@ -699,7 +699,7 @@ def log_jac_det(self, value, *inputs): return pt.zeros_like(value) -class ScaleTransform(RVTransform): +class ScaleTransform(Transform): name = "scale" def __init__(self, transform_args_fn): @@ -718,7 +718,7 @@ def log_jac_det(self, value, *inputs): return -pt.log(pt.abs(pt.broadcast_to(scale, value.shape))) -class LogTransform(RVTransform): +class LogTransform(Transform): name = "log" def forward(self, value, *inputs): @@ -731,7 +731,7 @@ def log_jac_det(self, value, *inputs): return value -class ExpTransform(RVTransform): +class ExpTransform(Transform): name = "exp" def forward(self, value, *inputs): @@ -744,7 +744,7 @@ def log_jac_det(self, value, *inputs): return -pt.log(value) -class AbsTransform(RVTransform): +class AbsTransform(Transform): name = "abs" def forward(self, value, *inputs): @@ -758,7 +758,7 @@ def log_jac_det(self, value, *inputs): return pt.switch(value >= 0, 0, np.nan) -class PowerTransform(RVTransform): +class PowerTransform(Transform): name = "power" def __init__(self, power=None): @@ -801,7 +801,7 @@ def log_jac_det(self, value, *inputs): return res -class IntervalTransform(RVTransform): +class IntervalTransform(Transform): name = "interval" def __init__(self, args_fn: Callable[..., Tuple[Optional[Variable], Optional[Variable]]]): @@ -909,7 +909,7 @@ def log_jac_det(self, value, *inputs): return pt.zeros_like(value) -class LogOddsTransform(RVTransform): +class LogOddsTransform(Transform): name = "logodds" def backward(self, value, *inputs): @@ -923,7 +923,7 @@ def log_jac_det(self, value, *inputs): return pt.log(sigmoid_value) + pt.log1p(-sigmoid_value) -class SimplexTransform(RVTransform): +class SimplexTransform(Transform): name = "simplex" def forward(self, value, *inputs): @@ -950,7 +950,7 @@ def log_jac_det(self, value, *inputs): return pt.sum(res, -1) -class CircularTransform(RVTransform): +class CircularTransform(Transform): name = "circular" def backward(self, value, *inputs): @@ -963,12 +963,11 @@ def log_jac_det(self, value, *inputs): return pt.zeros(value.shape) -class ChainedTransform(RVTransform): +class ChainedTransform(Transform): name = "chain" - def __init__(self, transform_list, base_op): + def __init__(self, transform_list): self.transform_list = transform_list - self.base_op = base_op def forward(self, value, *inputs): for transform in self.transform_list: diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index 783b9ad95d..fd256bf1a0 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -33,30 +33,18 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. - +import typing import warnings -from typing import ( - Callable, - Container, - Dict, - Generator, - Iterable, - List, - Optional, - Sequence, - Set, - Tuple, - Union, -) +from typing import Container, Dict, List, Optional, Sequence, Set, Tuple, Union import numpy as np +import pytensor from pytensor import Variable from pytensor import tensor as pt -from pytensor.graph import Apply, Op -from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs, walk -from pytensor.graph.fg import FunctionGraph +from pytensor.graph import Apply, Op, node_rewriter +from pytensor.graph.basic import walk from pytensor.graph.op import HasInnerGraph from pytensor.link.c.type import CType from pytensor.raise_op import CheckAndRaise @@ -64,130 +52,87 @@ from pytensor.tensor.variable import TensorVariable from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.pytensorf import replace_vars_in_graphs from pymc.util import makeiter +if typing.TYPE_CHECKING: + from pymc.logprob.transforms import Transform -def walk_model( - graphs: Iterable[TensorVariable], - walk_past_rvs: bool = False, - stop_at_vars: Optional[Set[TensorVariable]] = None, - expand_fn: Callable[[TensorVariable], List[TensorVariable]] = lambda var: [], -) -> Generator[TensorVariable, None, None]: - """Walk model graphs and yield their nodes. - By default, these walks will not go past ``MeasurableVariable`` nodes. +def replace_rvs_by_values( + graphs: Sequence[TensorVariable], + *, + rvs_to_values: Dict[TensorVariable, TensorVariable], + rvs_to_transforms: Optional[Dict[TensorVariable, "Transform"]] = None, +) -> List[TensorVariable]: + """Clone and replace random variables in graphs with their value variables. Parameters ---------- graphs - The graphs to walk. - walk_past_rvs - If ``True``, the walk will not terminate at ``MeasurableVariable``s. - stop_at_vars - A list of variables at which the walk will terminate. - expand_fn - A function that returns the next variable(s) to be traversed. + The graphs in which to perform the replacements. + rvs_to_values + Mapping between the original graph RVs and respective value variables + rvs_to_transforms, optional + Mapping between the original graph RVs and respective value transforms """ - if stop_at_vars is None: - stop_at_vars = set() - - def expand(var: TensorVariable, stop_at_vars=stop_at_vars) -> List[TensorVariable]: - new_vars = expand_fn(var) - - if ( - var.owner - and (walk_past_rvs or not isinstance(var.owner.op, MeasurableVariable)) - and (var not in stop_at_vars) - ): - new_vars.extend(reversed(var.owner.inputs)) - - return new_vars - - yield from walk(graphs, expand, False) - - -def replace_rvs_in_graphs( - graphs: Iterable[TensorVariable], - replacement_fn: Callable[ - [TensorVariable, Dict[TensorVariable, TensorVariable]], - Dict[TensorVariable, TensorVariable], - ], - initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None, - **kwargs, -) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]: - """Replace random variables in graphs. - This will *not* recompute test values. - - Parameters - ---------- - graphs - The graphs in which random variables are to be replaced. - - Returns - ------- - A ``tuple`` containing the transformed graphs and a ``dict`` of the - replacements that were made. - """ replacements = {} - if initial_replacements: - replacements.update(initial_replacements) - - def expand_replace(var: TensorVariable) -> List[TensorVariable]: - new_nodes: List[TensorVariable] = [] - if var.owner and isinstance(var.owner.op, MeasurableVariable): - new_nodes.extend(replacement_fn(var, replacements)) - return new_nodes - for var in walk_model(graphs, expand_fn=expand_replace, **kwargs): + def populate_replacements(var): + # Populate replacements dict with {rv: value} pairs indicating which graph + # RVs should be replaced by what value variables. + if not var.owner: + return [] + + next_vars = [] + value = rvs_to_values.get(var, None) + if value is not None: + rv = var + + if rvs_to_transforms is not None: + transform = rvs_to_transforms.get(rv, None) + if transform is not None: + # We want to replace uses of the RV by the back-transformation of its value + value = transform.backward(value, *rv.owner.inputs) + # The value may have a less precise type than the rv. In this case + # filter_variable will add a SpecifyShape to ensure they are consistent + value = rv.type.filter_variable(value, allow_convert=True) + value.name = rv.name + + replacements[rv] = value + # Also walk the graph of the value variable to make any additional + # replacements if that is not a simple input variable + next_vars.append(value) + + next_vars.extend(reversed(var.owner.inputs)) + return next_vars + + # Iterate over the generator to populate the replacements + for _ in walk(graphs, populate_replacements, bfs=False): pass - if replacements: - inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)] - equiv = {k: k for k in replacements.keys()} - equiv = clone_get_equiv(inputs, graphs, False, False, equiv) - - fg = FunctionGraph( - [equiv[i] for i in inputs], - [equiv[o] for o in graphs], - clone=False, - ) - - fg.replace_all(replacements.items(), import_missing=True) + return replace_vars_in_graphs(graphs, replacements) - graphs = list(fg.outputs) - return graphs, replacements - - -def rvs_to_value_vars( - graphs: Iterable[TensorVariable], - initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None, - **kwargs, -) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]: - """Replace random variables in graphs with their value variables. - - This will *not* recompute test values in the resulting graphs. +def rvs_in_graph(vars: Union[Variable, Sequence[Variable]]) -> Set[Variable]: + """Assert that there are no `MeasurableVariable` nodes in a graph.""" - Parameters - ---------- - graphs - The graphs in which to perform the replacements. - initial_replacements - A ``dict`` containing the initial replacements to be made. + def expand(r): + owner = r.owner + if owner: + inputs = list(reversed(owner.inputs)) - """ + if isinstance(owner.op, HasInnerGraph): + inputs += owner.op.inner_outputs - def replace_fn(var, replacements): - rv_value_var = replacements.get(var, None) - if rv_value_var is not None: - replacements[var] = rv_value_var - # In case the value variable is itself a graph, we walk it for - # potential replacements - return [rv_value_var] - return [] + return inputs - return replace_rvs_in_graphs(graphs, replace_fn, initial_replacements, **kwargs) + return { + node + for node in walk(makeiter(vars), expand, False) + if node.owner and isinstance(node.owner.op, (RandomVariable, MeasurableVariable)) + } def convert_indices(indices, entry): @@ -214,13 +159,18 @@ def indices_from_subtensor(idx_list, indices): def check_potential_measurability( inputs: Tuple[TensorVariable], valued_rvs: Container[TensorVariable] ) -> bool: + valued_rvs = set(valued_rvs) + + def expand_fn(var): + # expand_fn does not go beyond valued_rvs or any MeasurableVariable + if var.owner and not isinstance(var.owner.op, MeasurableVariable) and var not in valued_rvs: + return reversed(var.owner.inputs) + else: + return [] + if any( ancestor_var - for ancestor_var in walk_model( - inputs, - walk_past_rvs=False, - stop_at_vars=set(valued_rvs), - ) + for ancestor_var in walk(inputs, expand=expand_fn, bfs=False) if ( ancestor_var.owner and isinstance(ancestor_var.owner.op, MeasurableVariable) @@ -251,6 +201,48 @@ def __str__(self): return f"Check{{{self.msg}}}" +@node_rewriter(tracks=[CheckParameterValue]) +def local_remove_check_parameter(fgraph, node): + """Rewrite that removes CheckParameterValue + + This is used when compile_rv_inplace + """ + if isinstance(node.op, CheckParameterValue): + return [node.inputs[0]] + + +@node_rewriter(tracks=[CheckParameterValue]) +def local_check_parameter_to_ninf_switch(fgraph, node): + if not node.op.can_be_replaced_by_ninf: + return None + + logp_expr, *logp_conds = node.inputs + if len(logp_conds) > 1: + logp_cond = pt.all(logp_conds) + else: + (logp_cond,) = logp_conds + out = pt.switch(logp_cond, logp_expr, -np.inf) + out.name = node.op.msg + + if out.dtype != node.outputs[0].dtype: + out = pt.cast(out, node.outputs[0].dtype) + + return [out] + + +pytensor.compile.optdb["canonicalize"].register( + "local_remove_check_parameter", + local_remove_check_parameter, + use_db_name_as_tag=False, +) + +pytensor.compile.optdb["canonicalize"].register( + "local_check_parameter_to_ninf_switch", + local_check_parameter_to_ninf_switch, + use_db_name_as_tag=False, +) + + class DiracDelta(Op): """An `Op` that represents a Dirac-delta distribution.""" @@ -291,23 +283,3 @@ def diracdelta_logprob(op, values, *inputs, **kwargs): (const_value,) = inputs values, const_value = pt.broadcast_arrays(values, const_value) return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf) - - -def find_rvs_in_graph(vars: Union[Variable, Sequence[Variable]]) -> Set[Variable]: - """Assert that there are no `MeasurableVariable` nodes in a graph.""" - - def expand(r): - owner = r.owner - if owner: - inputs = list(reversed(owner.inputs)) - - if isinstance(owner.op, HasInnerGraph): - inputs += owner.op.inner_outputs - - return inputs - - return { - node - for node in walk(makeiter(vars), expand, False) - if node.owner and isinstance(node.owner.op, (RandomVariable, MeasurableVariable)) - } diff --git a/pymc/model/core.py b/pymc/model/core.py index 65ad468f75..4a0ff75733 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -65,7 +65,7 @@ ) from pymc.initial_point import make_initial_point_fn from pymc.logprob.basic import transformed_conditional_logp -from pymc.logprob.utils import ParameterValueError +from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values from pymc.model_graph import model_to_graphviz from pymc.pytensorf import ( PointFunc, @@ -75,7 +75,6 @@ gradient, hessian, inputvars, - replace_rvs_by_values, rewrite_pregrad, ) from pymc.util import ( diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py index 9cfce57fcb..eac640cebd 100644 --- a/pymc/model/fgraph.py +++ b/pymc/model/fgraph.py @@ -24,7 +24,7 @@ from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.sharedvar import ScalarSharedVariable -from pymc.logprob.transforms import RVTransform +from pymc.logprob.transforms import Transform from pymc.model.core import Model from pymc.pytensorf import StringType, find_rng_nodes, toposort_replace @@ -59,8 +59,8 @@ def perform(self, *args, **kwargs): class ModelValuedVar(ModelVar): __props__ = ("transform",) - def __init__(self, transform: Optional[RVTransform] = None): - if transform is not None and not isinstance(transform, RVTransform): + def __init__(self, transform: Optional[Transform] = None): + if transform is not None and not isinstance(transform, Transform): raise TypeError(f"transform must be None or RVTransform type, got {type(transform)}") self.transform = transform super().__init__() diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py index faa31339f7..87b9828fcf 100644 --- a/pymc/model/transform/conditioning.py +++ b/pymc/model/transform/conditioning.py @@ -13,17 +13,14 @@ # limitations under the License. import warnings -from typing import Any, List, Mapping, Optional, Sequence, Union +from typing import Any, Mapping, Optional, Sequence, Union -from pytensor import Variable from pytensor.graph import ancestors -from pytensor.graph.basic import walk -from pytensor.graph.op import HasInnerGraph from pytensor.tensor import TensorVariable -from pytensor.tensor.random.op import RandomVariable from pymc import Model -from pymc.logprob.transforms import RVTransform +from pymc.logprob.transforms import Transform +from pymc.logprob.utils import rvs_in_graph from pymc.model.fgraph import ( ModelDeterministic, ModelFreeRV, @@ -40,7 +37,7 @@ parse_vars, prune_vars_detached_from_observed, ) -from pymc.pytensorf import _replace_vars_in_graphs, toposort_replace +from pymc.pytensorf import replace_vars_in_graphs, toposort_replace from pymc.util import get_transformed_name, get_untransformed_name @@ -122,44 +119,6 @@ def observe( return model_from_fgraph(fgraph) -def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]: - def replacement_fn(var, inner_replacements): - if var in replacements: - inner_replacements[var] = replacements[var] - - # Handle root inputs as those will never be passed to the replacement_fn - for inp in var.owner.inputs: - if inp.owner is None and inp in replacements: - inner_replacements[inp] = replacements[inp] - - return [var] - - replaced_graphs, _ = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn) - return replaced_graphs - - -def rvs_in_graph(vars: Sequence[Variable]) -> bool: - """Check if there are any rvs in the graph of vars""" - - from pymc.distributions.distribution import SymbolicRandomVariable - - def expand(r): - owner = r.owner - if owner: - inputs = list(reversed(owner.inputs)) - - if isinstance(owner.op, HasInnerGraph): - inputs += owner.op.inner_outputs - - return inputs - - return any( - node - for node in walk(vars, expand, False) - if node.owner and isinstance(node.owner.op, (RandomVariable, SymbolicRandomVariable)) - ) - - def do( model: Model, vars_to_interventions: Mapping[Union["str", TensorVariable], Any], @@ -263,7 +222,7 @@ def do( def change_value_transforms( model: Model, - vars_to_transforms: Mapping[ModelVariable, Union[RVTransform, None]], + vars_to_transforms: Mapping[ModelVariable, Union[Transform, None]], ) -> Model: """Change the value variables transforms in the model diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 8480d0dcfb..8537c4a62e 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -35,7 +35,7 @@ from pytensor import scalar from pytensor.compile import Function, Mode, get_mode from pytensor.gradient import grad -from pytensor.graph import Type, node_rewriter, rewrite_graph +from pytensor.graph import Type, rewrite_graph from pytensor.graph.basic import ( Apply, Constant, @@ -63,8 +63,6 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable from pymc.exceptions import NotConstantValueError -from pymc.logprob.transforms import RVTransform -from pymc.logprob.utils import CheckParameterValue from pymc.util import makeiter from pymc.vartypes import continuous_types, isgenerator, typefilter @@ -192,6 +190,8 @@ def walk_model( expand_fn A function that returns the next variable(s) to be traversed. """ + warnings.warn("walk_model will be removed in a future relase of PyMC", FutureWarning) + if stop_at_vars is None: stop_at_vars = set() @@ -206,197 +206,34 @@ def expand(var): yield from walk(graphs, expand, bfs=False) -def _replace_vars_in_graphs( - graphs: Iterable[TensorVariable], - replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]], - **kwargs, -) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]: - """Replace variables in graphs. - - This will *not* recompute test values. - - Parameters - ---------- - graphs - The graphs in which random variables are to be replaced. - replacement_fn - A callable called on each graph output that populates a replacement dictionary and returns - nodes that should be investigated further. - - Returns - ------- - Tuple containing the transformed graphs and a ``dict`` of the replacements - that were made. - """ - replacements = {} - - def expand_replace(var): - new_nodes = [] - if var.owner: - # Call replacement_fn to update replacements dict inplace and, optionally, - # specify new nodes that should also be walked for replacements. This - # includes `value` variables that are not simple input variables, and may - # contain other `random` variables in their graphs (e.g., IntervalTransform) - new_nodes.extend(replacement_fn(var, replacements)) - return new_nodes - - # This iteration populates the replacements - for var in walk_model(graphs, expand_fn=expand_replace, **kwargs): - pass - - if replacements: - inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)] - equiv = {k: k for k in replacements.keys()} - equiv = clone_get_equiv(inputs, graphs, False, False, equiv) - - fg = FunctionGraph( - [equiv[i] for i in inputs], - [equiv[o] for o in graphs], - clone=False, - ) - - # replacements have to be done in reverse topological order so that nested - # expressions get recursively replaced correctly - toposort = fg.toposort() - sorted_replacements = sorted( - tuple(replacements.items()), - # Root inputs don't have owner, we give them negative priority -1 - key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner is not None else -1, - reverse=True, - ) - fg.replace_all(sorted_replacements, import_missing=True) - - graphs = list(fg.outputs) - - return graphs, replacements - - -def rvs_to_value_vars( +def replace_vars_in_graphs( graphs: Iterable[Variable], - apply_transforms: bool = True, - **kwargs, + replacements: Dict[Variable, Variable], ) -> List[Variable]: - """Clone and replace random variables in graphs with their value variables. - - This will *not* recompute test values in the resulting graphs. + """Replace variables in graphs. - Parameters - ---------- - graphs - The graphs in which to perform the replacements. - apply_transforms - If ``True``, apply each value variable's transform. + Graphs are cloned and not modified in place. """ - warnings.warn( - "rvs_to_value_vars is deprecated. Use model.replace_rvs_by_values instead", - FutureWarning, - ) - - def populate_replacements( - random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable] - ) -> List[TensorVariable]: - # Populate replacements dict with {rv: value} pairs indicating which graph - # RVs should be replaced by what value variables. - - value_var = getattr( - random_var.tag, "observations", getattr(random_var.tag, "value_var", None) - ) - - # No value variable to replace RV with - if value_var is None: - return [] - - transform = getattr(value_var.tag, "transform", None) - if transform is not None and apply_transforms: - # We want to replace uses of the RV by the back-transformation of its value - value_var = transform.backward(value_var, *random_var.owner.inputs) - - replacements[random_var] = value_var - - # Also walk the graph of the value variable to make any additional replacements - # if that is not a simple input variable - return [value_var] - - # Clone original graphs + # Clone graph and get equivalences inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)] - equiv = clone_get_equiv(inputs, graphs, False, False, {}) - graphs = [equiv[n] for n in graphs] + equiv = {k: k for k in replacements.keys()} + equiv = clone_get_equiv(inputs, graphs, False, False, equiv) - graphs, _ = _replace_vars_in_graphs( - graphs, - replacement_fn=populate_replacements, - **kwargs, + fg = FunctionGraph( + [equiv[i] for i in inputs], + [equiv[o] for o in graphs], + clone=False, ) - return graphs - - -def replace_rvs_by_values( - graphs: Sequence[TensorVariable], - *, - rvs_to_values: Dict[TensorVariable, TensorVariable], - rvs_to_transforms: Optional[Dict[TensorVariable, RVTransform]] = None, - **kwargs, -) -> List[TensorVariable]: - """Clone and replace random variables in graphs with their value variables. - - This will *not* recompute test values in the resulting graphs. + # Filter replacement keys that are actually present in the graph + vars = fg.variables + final_replacements = tuple((k, v) for k, v in replacements.items() if k in vars) - Parameters - ---------- - graphs - The graphs in which to perform the replacements. - rvs_to_values - Mapping between the original graph RVs and respective value variables - rvs_to_transforms, optional - Mapping between the original graph RVs and respective value transforms - """ - - # Clone original graphs so that we don't modify variables in place - inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)] - equiv = clone_get_equiv(inputs, graphs, False, False, {}) - graphs = [equiv[n] for n in graphs] - - # Get needed mappings for equivalent cloned variables - equiv_rvs_to_values = {} - equiv_rvs_to_transforms = {} - for rv, value in rvs_to_values.items(): - equiv_rv = equiv.get(rv, rv) - equiv_rvs_to_values[equiv_rv] = equiv.get(value, value) - if rvs_to_transforms is not None: - equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv] - - def poulate_replacements(rv, replacements): - # Populate replacements dict with {rv: value} pairs indicating which graph - # RVs should be replaced by what value variables. - - # No value variable to replace RV with - value = equiv_rvs_to_values.get(rv, None) - if value is None: - return [] - - if rvs_to_transforms is not None: - transform = equiv_rvs_to_transforms.get(rv, None) - if transform is not None: - # We want to replace uses of the RV by the back-transformation of its value - value = transform.backward(value, *rv.owner.inputs) - # The value may have a less precise type than the rv. In this case - # filter_variable will add a SpecifyShape to ensure they are consistent - value = rv.type.filter_variable(value, allow_convert=True) - value.name = rv.name - - replacements[rv] = value - # Also walk the graph of the value variable to make any additional - # replacements if that is not a simple input variable - return [value] - - graphs, _ = _replace_vars_in_graphs( - graphs, - replacement_fn=poulate_replacements, - **kwargs, - ) + # Replacements have to be done in reverse topological order so that nested + # expressions get recursively replaced correctly + toposort_replace(fg, final_replacements, reverse=True) - return graphs + return list(fg.outputs) def inputvars(a): @@ -899,48 +736,6 @@ def largest_common_dtype(tensors): return np.stack([np.ones((), dtype=dtype) for dtype in dtypes]).dtype -@node_rewriter(tracks=[CheckParameterValue]) -def local_remove_check_parameter(fgraph, node): - """Rewrite that removes CheckParameterValue - - This is used when compile_rv_inplace - """ - if isinstance(node.op, CheckParameterValue): - return [node.inputs[0]] - - -@node_rewriter(tracks=[CheckParameterValue]) -def local_check_parameter_to_ninf_switch(fgraph, node): - if not node.op.can_be_replaced_by_ninf: - return None - - logp_expr, *logp_conds = node.inputs - if len(logp_conds) > 1: - logp_cond = pt.all(logp_conds) - else: - (logp_cond,) = logp_conds - out = pt.switch(logp_cond, logp_expr, -np.inf) - out.name = node.op.msg - - if out.dtype != node.outputs[0].dtype: - out = pt.cast(out, node.outputs[0].dtype) - - return [out] - - -pytensor.compile.optdb["canonicalize"].register( - "local_remove_check_parameter", - local_remove_check_parameter, - use_db_name_as_tag=False, -) - -pytensor.compile.optdb["canonicalize"].register( - "local_check_parameter_to_ninf_switch", - local_check_parameter_to_ninf_switch, - use_db_name_as_tag=False, -) - - def find_rng_nodes( variables: Iterable[Variable], ) -> List[Union[RandomStateSharedVariable, RandomGeneratorSharedVariable]]: diff --git a/pymc/testing.py b/pymc/testing.py index 3eb1b7ba81..4315999789 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -36,14 +36,12 @@ from pymc.distributions.shape_utils import change_dist_size from pymc.initial_point import make_initial_point_fn from pymc.logprob.basic import icdf, logcdf, logp, transformed_conditional_logp -from pymc.logprob.utils import ParameterValueError, find_rvs_in_graph -from pymc.pytensorf import ( - compile_pymc, - floatX, - inputvars, - intX, +from pymc.logprob.utils import ( + ParameterValueError, local_check_parameter_to_ninf_switch, + rvs_in_graph, ) +from pymc.pytensorf import compile_pymc, floatX, inputvars, intX # This mode can be used for tests where model compilations takes the bulk of the runtime # AND where we don't care about posterior numerical or sampling stability (e.g., when @@ -952,6 +950,6 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable: def assert_no_rvs(vars: Sequence[Variable]) -> None: """Assert that there are no `MeasurableVariable` nodes in a graph.""" - rvs = find_rvs_in_graph(vars) + rvs = rvs_in_graph(vars) if rvs: raise AssertionError(f"RV found in graph: {rvs}") diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index e9027dcf3e..8196c2623c 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -27,7 +27,7 @@ import pymc.distributions.transforms as tr from pymc.logprob.basic import transformed_conditional_logp -from pymc.logprob.transforms import RVTransform +from pymc.logprob.transforms import Transform from pymc.pytensorf import floatX, jacobian from pymc.testing import ( Circ, @@ -632,7 +632,7 @@ def test_univariate_transform_multivariate_dist_raises(): def test_invalid_jacobian_broadcast_raises(): - class BuggyTransform(RVTransform): + class BuggyTransform(Transform): name = "buggy" def forward(self, value, *inputs): diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py index 456a8f277a..b00c8ec7a0 100644 --- a/tests/logprob/test_basic.py +++ b/tests/logprob/test_basic.py @@ -63,8 +63,7 @@ transformed_conditional_logp, ) from pymc.logprob.transforms import LogTransform -from pymc.logprob.utils import rvs_to_value_vars, walk_model -from pymc.pytensorf import replace_rvs_by_values +from pymc.logprob.utils import replace_rvs_by_values from pymc.testing import assert_no_rvs @@ -93,9 +92,9 @@ def test_factorized_joint_logprob_basic(): # We need to replace the reference to `sigma` in `Y` with its value # variable ll_Y = logp(Y, y_value_var) - (ll_Y,), _ = rvs_to_value_vars( + (ll_Y,) = replace_rvs_by_values( [ll_Y], - initial_replacements={sigma: sigma_value_var}, + rvs_to_values={sigma: sigma_value_var}, ) total_ll_exp = logp(sigma, sigma_value_var) + ll_Y @@ -118,7 +117,7 @@ def test_factorized_joint_logprob_basic(): # There shouldn't be any `RandomVariable`s in the resulting graph assert_no_rvs(b_logp_combined) - res_ancestors = list(walk_model((b_logp_combined,), walk_past_rvs=True)) + res_ancestors = list(ancestors((b_logp_combined,))) assert b_value_var in res_ancestors assert c_value_var in res_ancestors assert a_value_var in res_ancestors @@ -274,7 +273,7 @@ def test_joint_logp_basic(): # There shouldn't be any `RandomVariable`s in the resulting graph assert_no_rvs(b_logp) - res_ancestors = list(walk_model((b_logp,))) + res_ancestors = list(ancestors((b_logp,))) assert b_value_var in res_ancestors assert c_value_var in res_ancestors assert a_value_var in res_ancestors diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 9acd3ab1ce..f0747a9338 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -57,10 +57,10 @@ ExpTransform, LocTransform, LogTransform, - RVTransform, ScaleTransform, SinhTransform, TanhTransform, + Transform, ) from pymc.logprob.utils import ParameterValueError from pymc.testing import Rplusbig, Vector, assert_no_rvs @@ -91,7 +91,7 @@ def logpdf(self, value): return res -class TestRVTransform: +class TestTransform: @pytest.mark.parametrize("ndim", (0, 1)) def test_fallback_log_jac_det(self, ndim): """ @@ -99,7 +99,7 @@ def test_fallback_log_jac_det(self, ndim): simple transformation: x**2 -> -log(2*x) """ - class SquareTransform(RVTransform): + class SquareTransform(Transform): name = "square" ndim_supp = ndim @@ -123,7 +123,7 @@ def backward(self, value, *inputs): @pytest.mark.parametrize("ndim", (None, 2)) def test_fallback_log_jac_det_undefined_ndim(self, ndim): - class SquareTransform(RVTransform): + class SquareTransform(Transform): name = "square" ndim_supp = ndim @@ -152,7 +152,6 @@ def test_chained_transform(self): transform_args_fn=lambda *inputs: pt.constant(loc), ), ], - base_op=pt.random.multivariate_normal, ) x = pt.random.multivariate_normal(np.zeros(3), np.eye(3)) diff --git a/tests/logprob/test_utils.py b/tests/logprob/test_utils.py index 320de6a36a..0e14f7e24f 100644 --- a/tests/logprob/test_utils.py +++ b/tests/logprob/test_utils.py @@ -34,127 +34,238 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import warnings - import numpy as np import pytensor -import pytensor.tensor as pt import pytest from pytensor import function +from pytensor import tensor as pt from pytensor.compile import get_default_mode -from pytensor.tensor.random.basic import normal, uniform +from pytensor.graph.basic import ancestors, equal_computations +from pytensor.tensor.random.op import RandomVariable import pymc as pm +from pymc import SymbolicRandomVariable +from pymc.distributions.transforms import Interval from pymc.logprob.abstract import MeasurableVariable -from pymc.logprob.basic import logp, transformed_conditional_logp +from pymc.logprob.basic import logp from pymc.logprob.utils import ( ParameterValueError, check_potential_measurability, dirac_delta, - rvs_to_value_vars, - walk_model, + replace_rvs_by_values, ) from pymc.testing import assert_no_rvs from tests.logprob.utils import create_pytensor_params, scipy_logprob_tester -def test_walk_model(): - d = pt.vector("d") - b = pt.vector("b") - c = uniform(0.0, d) - c.name = "c" - e = pt.log(c) - a = normal(e, b) - a.name = "a" - - test_graph = pt.exp(a + 1) - res = list(walk_model((test_graph,))) - assert a in res - assert c not in res - - res = list(walk_model((test_graph,), walk_past_rvs=True)) - assert a in res - assert c in res - - res = list(walk_model((test_graph,), walk_past_rvs=True, stop_at_vars={e})) - assert a in res - assert c not in res - - -def test_rvs_to_value_vars(): - a = pt.random.uniform(0.0, 1.0) - a.name = "a" - a.tag.value_var = a_value_var = a.clone() - - b = pt.random.uniform(0, a + 1.0) - b.name = "b" - b.tag.value_var = b_value_var = b.clone() - - c = pt.random.normal() - c.name = "c" - c.tag.value_var = c_value_var = c.clone() - - d = pt.log(c + b) + 2.0 - - initial_replacements = {b: b_value_var, c: c_value_var} - (res,), replaced = rvs_to_value_vars((d,), initial_replacements=initial_replacements) - - assert res.owner.op == pt.add - log_output = res.owner.inputs[0] - assert log_output.owner.op == pt.log - log_add_output = res.owner.inputs[0].owner.inputs[0] - assert log_add_output.owner.op == pt.add - c_output = log_add_output.owner.inputs[0] - - # We make sure that the random variables were replaced - # with their value variables - assert c_output == c_value_var - b_output = log_add_output.owner.inputs[1] - assert b_output == b_value_var - - # There shouldn't be any `RandomVariable`s in the resulting graph - assert_no_rvs(res) - - res_ancestors = list(walk_model((res,), walk_past_rvs=True)) - - assert b_value_var in res_ancestors - assert c_value_var in res_ancestors - assert a_value_var not in res_ancestors +class TestReplaceRVsByValues: + @pytest.mark.parametrize("symbolic_rv", (False, True)) + @pytest.mark.parametrize("apply_transforms", (True, False)) + def test_basic(self, symbolic_rv, apply_transforms): + # Interval transform between last two arguments + interval = ( + Interval(bounds_fn=lambda *args: (args[-2], args[-1])) if apply_transforms else None + ) + with pm.Model() as m: + a = pm.Uniform("a", 0.0, 1.0) + if symbolic_rv: + raw_b = pm.Uniform.dist(0, a + 1.0) + b = pm.Censored("b", raw_b, lower=0, upper=a + 1.0, transform=interval) + # If not True, another distribution has to be used + assert isinstance(b.owner.op, SymbolicRandomVariable) + else: + b = pm.Uniform("b", 0, a + 1.0, transform=interval) + c = pm.Normal("c") + d = pt.log(c + b) + 2.0 + + a_value_var = m.rvs_to_values[a] + assert m.rvs_to_transforms[a] is not None + + b_value_var = m.rvs_to_values[b] + c_value_var = m.rvs_to_values[c] + + (res,) = replace_rvs_by_values( + (d,), + rvs_to_values=m.rvs_to_values, + rvs_to_transforms=m.rvs_to_transforms, + ) -def test_rvs_to_value_vars_intermediate_rv(): - """Test that function replaces values above an intermediate RV.""" - a = pt.random.uniform(0.0, 1.0) - a.name = "a" - a.tag.value_var = a_value_var = a.clone() + assert res.owner.op == pt.add + log_output = res.owner.inputs[0] + assert log_output.owner.op == pt.log + log_add_output = res.owner.inputs[0].owner.inputs[0] + assert log_add_output.owner.op == pt.add + c_output = log_add_output.owner.inputs[0] + + # We make sure that the random variables were replaced + # with their value variables + assert c_output == c_value_var + b_output = log_add_output.owner.inputs[1] + # When transforms are applied, the input is the back-transformation of the value_var, + # otherwise it is the value_var itself + if apply_transforms: + assert b_output != b_value_var + else: + assert b_output == b_value_var + + res_ancestors = list(ancestors((res,))) + res_rv_ancestors = [ + v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable) + ] + + # There shouldn't be any `RandomVariable`s in the resulting graph + assert len(res_rv_ancestors) == 0 + assert b_value_var in res_ancestors + assert c_value_var in res_ancestors + # When transforms are used, `d` depends on `a` through the back-transformation of + # `b`, otherwise there is no direct connection between `d` and `a` + if apply_transforms: + assert a_value_var in res_ancestors + else: + assert a_value_var not in res_ancestors + + def test_intermediate_rv(self): + """Test that function replaces values above an intermediate RV.""" + a = pt.random.uniform(0.0, 1.0) + a.name = "a" + a.tag.value_var = a_value_var = a.clone() + + b = pt.random.uniform(0, a + 1.0) + b.name = "b" + b.tag.value_var = b.clone() + + c = pt.random.normal() + c.name = "c" + c.tag.value_var = c_value_var = c.clone() + + d = pt.log(c + b) + 2.0 + + initial_replacements = {a: a_value_var, c: c_value_var} + (res,) = replace_rvs_by_values((d,), rvs_to_values=initial_replacements) + + # Assert that the only RandomVariable that remains in the graph is `b` + res_ancestors = list(ancestors((res,))) + + assert ( + len( + list( + n + for n in res_ancestors + if n.owner and isinstance(n.owner.op, MeasurableVariable) + ) + ) + == 1 + ) - b = pt.random.uniform(0, a + 1.0) - b.name = "b" - b.tag.value_var = b.clone() + assert c_value_var in res_ancestors + assert a_value_var in res_ancestors - c = pt.random.normal() - c.name = "c" - c.tag.value_var = c_value_var = c.clone() + def test_unvalued_rv_model(self): + with pm.Model() as m: + x = pm.Normal("x") + y = pm.Normal.dist(x) + z = pm.Normal("z", y) + out = z + y - d = pt.log(c + b) + 2.0 + x_value = m.rvs_to_values[x] + z_value = m.rvs_to_values[z] - initial_replacements = {a: a_value_var, c: c_value_var} - (res,), replaced = rvs_to_value_vars((d,), initial_replacements=initial_replacements) + (res,) = replace_rvs_by_values( + (out,), + rvs_to_values=m.rvs_to_values, + rvs_to_transforms=m.rvs_to_transforms, + ) - # Assert that the only RandomVariable that remains in the graph is `b` - res_ancestors = list(walk_model((res,), walk_past_rvs=True)) + assert res.owner.op == pt.add + assert res.owner.inputs[0] is z_value + res_y = res.owner.inputs[1] + # Graph should have be cloned, and therefore y and res_y should have different ids + assert res_y is not y + assert res_y.owner.op == pt.random.normal + assert res_y.owner.inputs[3] is x_value + + def test_no_change_inplace(self): + # Test that calling rvs_to_value_vars in models with nested transformations + # does not change the original rvs in place. See issue #5172 + with pm.Model() as m: + one = pm.LogNormal("one", mu=0) + two = pm.LogNormal("two", mu=pt.log(one)) + + # We add potentials or deterministics that are not in topological order + pm.Potential("two_pot", two) + pm.Potential("one_pot", one) + + before = pytensor.clone_replace(m.free_RVs) + + # This call would change the model free_RVs in place in #5172 + replace_rvs_by_values( + m.potentials, + rvs_to_values=m.rvs_to_values, + rvs_to_transforms=m.rvs_to_transforms, + ) - assert ( - len( - list(n for n in res_ancestors if n.owner and isinstance(n.owner.op, MeasurableVariable)) + after = pytensor.clone_replace(m.free_RVs) + assert equal_computations(before, after) + + @pytest.mark.parametrize("reversed", (False, True)) + def test_interdependent_transformed_rvs(self, reversed): + # Test that nested transformed variables, whose transformed values depend on other + # RVs are properly replaced + with pm.Model() as m: + transform = pm.distributions.transforms.Interval( + bounds_fn=lambda *inputs: (inputs[-2], inputs[-1]) + ) + x = pm.Uniform("x", lower=0, upper=1, transform=transform) + y = pm.Uniform("y", lower=0, upper=x, transform=transform) + z = pm.Uniform("z", lower=0, upper=y, transform=transform) + w = pm.Uniform("w", lower=0, upper=z, transform=transform) + + rvs = [x, y, z, w] + if reversed: + rvs = rvs[::-1] + + transform_values = replace_rvs_by_values( + rvs, + rvs_to_values=m.rvs_to_values, + rvs_to_transforms=m.rvs_to_transforms, ) - == 1 - ) - assert c_value_var in res_ancestors - assert a_value_var in res_ancestors + for transform_value in transform_values: + assert_no_rvs(transform_value) + + if reversed: + transform_values = transform_values[::-1] + transform_values_fn = m.compile_fn(transform_values, point_fn=False) + + x_interval_test_value = np.random.rand() + y_interval_test_value = np.random.rand() + z_interval_test_value = np.random.rand() + w_interval_test_value = np.random.rand() + + # The 3 Nones correspond to unused rng, dtype and size arguments + expected_x = transform.backward(x_interval_test_value, None, None, None, 0, 1).eval() + expected_y = transform.backward( + y_interval_test_value, None, None, None, 0, expected_x + ).eval() + expected_z = transform.backward( + z_interval_test_value, None, None, None, 0, expected_y + ).eval() + expected_w = transform.backward( + w_interval_test_value, None, None, None, 0, expected_z + ).eval() + + np.testing.assert_allclose( + transform_values_fn( + x_interval__=x_interval_test_value, + y_interval__=y_interval_test_value, + z_interval__=z_interval_test_value, + w_interval__=w_interval_test_value, + ), + [expected_x, expected_y, expected_z, expected_w], + ) def test_CheckParameter(): diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index fac7b70462..f430d6b5e4 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -24,9 +24,8 @@ from pytensor import scan, shared from pytensor.compile.builders import OpFromGraph -from pytensor.graph.basic import Variable, equal_computations +from pytensor.graph.basic import Variable from pytensor.tensor.random.basic import normal, uniform -from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.var import RandomStateSharedVariable from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 from pytensor.tensor.variable import TensorVariable @@ -35,23 +34,19 @@ from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import SymbolicRandomVariable -from pymc.distributions.transforms import Interval from pymc.exceptions import NotConstantValueError from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import ( - _replace_vars_in_graphs, collect_default_updates, compile_pymc, constant_fold, convert_observed_data, extract_obs_data, replace_rng_nodes, - replace_rvs_by_values, + replace_vars_in_graphs, reseed_rngs, - rvs_to_value_vars, walk_model, ) -from pymc.testing import assert_no_rvs from pymc.vartypes import int_types @@ -286,21 +281,24 @@ def test_walk_model(): test_graph = pt.exp(e + 1) - res = list(walk_model((test_graph,))) + with pytest.warns(FutureWarning): + res = list(walk_model((test_graph,))) assert a in res assert b in res assert c in res assert d in res assert e in res - res = list(walk_model((test_graph,), stop_at_vars={c})) + with pytest.warns(FutureWarning): + res = list(walk_model((test_graph,), stop_at_vars={c})) assert a not in res assert b not in res assert c in res assert d in res assert e in res - res = list(walk_model((test_graph,), stop_at_vars={b})) + with pytest.warns(FutureWarning): + res = list(walk_model((test_graph,), stop_at_vars={b})) assert a not in res assert b in res assert c in res @@ -668,211 +666,15 @@ def test_constant_fold_raises(): assert tuple(res[1].eval()) == (5,) -class TestReplaceRVsByValues: - @pytest.mark.parametrize("symbolic_rv", (False, True)) - @pytest.mark.parametrize("apply_transforms", (True, False)) - @pytest.mark.parametrize("test_deprecated_fn", (True, False)) - def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn): - # Interval transform between last two arguments - interval = ( - Interval(bounds_fn=lambda *args: (args[-2], args[-1])) if apply_transforms else None - ) - - with pm.Model() as m: - a = pm.Uniform("a", 0.0, 1.0) - if symbolic_rv: - raw_b = pm.Uniform.dist(0, a + 1.0) - b = pm.Censored("b", raw_b, lower=0, upper=a + 1.0, transform=interval) - # If not True, another distribution has to be used - assert isinstance(b.owner.op, SymbolicRandomVariable) - else: - b = pm.Uniform("b", 0, a + 1.0, transform=interval) - c = pm.Normal("c") - d = pt.log(c + b) + 2.0 - - a_value_var = m.rvs_to_values[a] - assert m.rvs_to_transforms[a] is not None - - b_value_var = m.rvs_to_values[b] - c_value_var = m.rvs_to_values[c] - - if test_deprecated_fn: - with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"): - (res,) = rvs_to_value_vars((d,), apply_transforms=apply_transforms) - else: - (res,) = replace_rvs_by_values( - (d,), - rvs_to_values=m.rvs_to_values, - rvs_to_transforms=m.rvs_to_transforms, - ) - - assert res.owner.op == pt.add - log_output = res.owner.inputs[0] - assert log_output.owner.op == pt.log - log_add_output = res.owner.inputs[0].owner.inputs[0] - assert log_add_output.owner.op == pt.add - c_output = log_add_output.owner.inputs[0] - - # We make sure that the random variables were replaced - # with their value variables - assert c_output == c_value_var - b_output = log_add_output.owner.inputs[1] - # When transforms are applied, the input is the back-transformation of the value_var, - # otherwise it is the value_var itself - if apply_transforms: - assert b_output != b_value_var - else: - assert b_output == b_value_var - - res_ancestors = list(walk_model((res,))) - res_rv_ancestors = [ - v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable) - ] - - # There shouldn't be any `RandomVariable`s in the resulting graph - assert len(res_rv_ancestors) == 0 - assert b_value_var in res_ancestors - assert c_value_var in res_ancestors - # When transforms are used, `d` depends on `a` through the back-transformation of - # `b`, otherwise there is no direct connection between `d` and `a` - if apply_transforms: - assert a_value_var in res_ancestors - else: - assert a_value_var not in res_ancestors - - @pytest.mark.parametrize("test_deprecated_fn", (True, False)) - def test_unvalued_rv(self, test_deprecated_fn): - with pm.Model() as m: - x = pm.Normal("x") - y = pm.Normal.dist(x) - z = pm.Normal("z", y) - out = z + y - - x_value = m.rvs_to_values[x] - z_value = m.rvs_to_values[z] - - if test_deprecated_fn: - with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"): - (res,) = rvs_to_value_vars((out,)) - else: - (res,) = replace_rvs_by_values( - (out,), - rvs_to_values=m.rvs_to_values, - rvs_to_transforms=m.rvs_to_transforms, - ) - - assert res.owner.op == pt.add - assert res.owner.inputs[0] is z_value - res_y = res.owner.inputs[1] - # Graph should have be cloned, and therefore y and res_y should have different ids - assert res_y is not y - assert res_y.owner.op == pt.random.normal - assert res_y.owner.inputs[3] is x_value - - @pytest.mark.parametrize("test_deprecated_fn", (True, False)) - def test_no_change_inplace(self, test_deprecated_fn): - # Test that calling rvs_to_value_vars in models with nested transformations - # does not change the original rvs in place. See issue #5172 - with pm.Model() as m: - one = pm.LogNormal("one", mu=0) - two = pm.LogNormal("two", mu=pt.log(one)) - - # We add potentials or deterministics that are not in topological order - pm.Potential("two_pot", two) - pm.Potential("one_pot", one) - - before = pytensor.clone_replace(m.free_RVs) - - # This call would change the model free_RVs in place in #5172 - if test_deprecated_fn: - with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"): - rvs_to_value_vars(m.potentials) - else: - replace_rvs_by_values( - m.potentials, - rvs_to_values=m.rvs_to_values, - rvs_to_transforms=m.rvs_to_transforms, - ) - - after = pytensor.clone_replace(m.free_RVs) - assert equal_computations(before, after) - - @pytest.mark.parametrize("test_deprecated_fn", (True, False)) - @pytest.mark.parametrize("reversed", (False, True)) - def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn): - # Test that nested transformed variables, whose transformed values depend on other - # RVs are properly replaced - with pm.Model() as m: - transform = pm.distributions.transforms.Interval( - bounds_fn=lambda *inputs: (inputs[-2], inputs[-1]) - ) - x = pm.Uniform("x", lower=0, upper=1, transform=transform) - y = pm.Uniform("y", lower=0, upper=x, transform=transform) - z = pm.Uniform("z", lower=0, upper=y, transform=transform) - w = pm.Uniform("w", lower=0, upper=z, transform=transform) - - rvs = [x, y, z, w] - if reversed: - rvs = rvs[::-1] - - if test_deprecated_fn: - with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"): - transform_values = rvs_to_value_vars(rvs) - else: - transform_values = replace_rvs_by_values( - rvs, - rvs_to_values=m.rvs_to_values, - rvs_to_transforms=m.rvs_to_transforms, - ) - - for transform_value in transform_values: - assert_no_rvs(transform_value) - - if reversed: - transform_values = transform_values[::-1] - transform_values_fn = m.compile_fn(transform_values, point_fn=False) - - x_interval_test_value = np.random.rand() - y_interval_test_value = np.random.rand() - z_interval_test_value = np.random.rand() - w_interval_test_value = np.random.rand() - - # The 3 Nones correspond to unused rng, dtype and size arguments - expected_x = transform.backward(x_interval_test_value, None, None, None, 0, 1).eval() - expected_y = transform.backward( - y_interval_test_value, None, None, None, 0, expected_x - ).eval() - expected_z = transform.backward( - z_interval_test_value, None, None, None, 0, expected_y - ).eval() - expected_w = transform.backward( - w_interval_test_value, None, None, None, 0, expected_z - ).eval() - - np.testing.assert_allclose( - transform_values_fn( - x_interval__=x_interval_test_value, - y_interval__=y_interval_test_value, - z_interval__=z_interval_test_value, - w_interval__=w_interval_test_value, - ), - [expected_x, expected_y, expected_z, expected_w], - ) - - def test_replace_input(self): - inp = shared(0.0, name="inp") - x = pm.Normal.dist(inp) - - assert x.eval() < 50 - - new_inp = inp + 100 +def test_replace_vars_in_graphs(): + inp = shared(0.0, name="inp") + x = pm.Normal.dist(inp) - def replacement_fn(var, replacements): - if var is x: - replacements[x.owner.inputs[3]] = new_inp + assert x.eval() < 50 - return [] + new_inp = inp + 100 - [new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn) + replacements = {x.owner.inputs[3]: new_inp} + [new_x] = replace_vars_in_graphs([x], replacements=replacements) - assert new_x.eval() > 50 + assert new_x.eval() > 50 diff --git a/tests/test_util.py b/tests/test_util.py index 1b41d3a734..faf22aabf2 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -22,7 +22,7 @@ import pymc as pm -from pymc.distributions.transforms import RVTransform +from pymc.distributions.transforms import Transform from pymc.util import ( UNSET, _get_seeds_per_chain, @@ -40,7 +40,7 @@ class TestTransformName: transform_name = "test" def test_get_transformed_name(self): - class NewTransform(RVTransform): + class NewTransform(Transform): name = self.transform_name def forward(self, value):