diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 360a8199..4deda063 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -10,6 +10,6 @@ dependencies: - xhistogram - statsmodels - pip: - - pymc>=5.16.1 # CI was failing to resolve + - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 360a8199..4deda063 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -10,6 +10,6 @@ dependencies: - xhistogram - statsmodels - pip: - - pymc>=5.16.1 # CI was failing to resolve + - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index 77dd3b22..62f1293c 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -16,7 +16,7 @@ from pymc_experimental import gp, statespace, utils from pymc_experimental.distributions import * from pymc_experimental.inference.fit import fit -from pymc_experimental.model.marginal_model import MarginalModel +from pymc_experimental.model.marginal.marginal_model import MarginalModel from pymc_experimental.model.model_api import as_model from pymc_experimental.version import __version__ diff --git a/pymc_experimental/model/marginal/__init__.py b/pymc_experimental/model/marginal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pymc_experimental/model/marginal/distributions.py b/pymc_experimental/model/marginal/distributions.py new file mode 100644 index 00000000..661665e9 --- /dev/null +++ b/pymc_experimental/model/marginal/distributions.py @@ -0,0 +1,276 @@ +from collections.abc import Sequence + +import numpy as np +import pytensor.tensor as pt + +from pymc.distributions import Bernoulli, Categorical, DiscreteUniform +from pymc.logprob.abstract import MeasurableOp, _logprob +from pymc.logprob.basic import conditional_logp, logp +from pymc.pytensorf import constant_fold +from pytensor import Variable +from pytensor.compile.builders import OpFromGraph +from pytensor.compile.mode import Mode +from pytensor.graph import Op, vectorize_graph +from pytensor.graph.replace import clone_replace, graph_replace +from pytensor.scan import map as scan_map +from pytensor.scan import scan +from pytensor.tensor import TensorVariable + +from pymc_experimental.distributions import DiscreteMarkovChain + + +class MarginalRV(OpFromGraph, MeasurableOp): + """Base class for Marginalized RVs""" + + def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None: + self.dims_connections = dims_connections + super().__init__(*args, **kwargs) + + @property + def support_axes(self) -> tuple[tuple[int]]: + """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable.""" + marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp + support_axes_vars = [] + for dims_connection in self.dims_connections: + ndim = len(dims_connection) + marginalized_supp_axes = ndim - marginalized_ndim_supp + support_axes_vars.append( + tuple( + -i + for i, dim in enumerate(reversed(dims_connection), start=1) + if (dim is None or dim > marginalized_supp_axes) + ) + ) + return tuple(support_axes_vars) + + +class MarginalFiniteDiscreteRV(MarginalRV): + """Base class for Marginalized Finite Discrete RVs""" + + +class MarginalDiscreteMarkovChainRV(MarginalRV): + """Base class for Marginalized Discrete Markov 
Chain RVs"""
+
+
+def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
+    op = rv.owner.op
+    dist_params = rv.owner.op.dist_params(rv.owner)
+    if isinstance(op, Bernoulli):
+        return (0, 1)
+    elif isinstance(op, Categorical):
+        [p_param] = dist_params
+        [p_param_length] = constant_fold([p_param.shape[-1]])
+        return tuple(range(p_param_length))
+    elif isinstance(op, DiscreteUniform):
+        lower, upper = constant_fold(dist_params)
+        return tuple(np.arange(lower, upper + 1))
+    elif isinstance(op, DiscreteMarkovChain):
+        P, *_ = dist_params
+        return tuple(range(pt.get_vector_length(P[-1])))
+
+    raise NotImplementedError(f"Cannot compute domain for op {op}")
+
+
+def reduce_batch_dependent_logps(
+    dependent_dims_connections: Sequence[tuple[int | None, ...]],
+    dependent_ops: Sequence[Op],
+    dependent_logps: Sequence[TensorVariable],
+) -> TensorVariable:
+    """Combine the logps of dependent RVs and align them with the marginalized logp.
+
+    This requires reducing extra batch dims and transposing when they are not aligned.
+
+    idx = pm.Bernoulli("idx", p, shape=(3, 2))  # 0, 1
+    pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
+    pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))
+
+    marginalize(idx)
+
+    The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)]
+    which tells us we need to reduce the last axis of dep1 logp and the first of dep2 logp,
+    as well as transpose the remaining axes of dep1 logp before adding the two element-wise.
+
+    """
+    from pymc_experimental.model.marginal.graph_analysis import get_support_axes
+
+    reduced_logps = []
+    for dependent_op, dependent_logp, dependent_dims_connection in zip(
+        dependent_ops, dependent_logps, dependent_dims_connections
+    ):
+        if dependent_logp.type.ndim > 0:
+            # Find which support axes implied by the MarginalRV need to be reduced
+            # Some may have already been reduced by the logp expression of the dependent RV (e.g., multivariate RVs)
+            dep_supp_axes = get_support_axes(dependent_op)[0]
+
+            # Dependent RV support axes are already collapsed in the logp, so we ignore them
+            supp_axes = [
+                -i
+                for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
+                if (dim is None and -i not in dep_supp_axes)
+            ]
+            dependent_logp = dependent_logp.sum(supp_axes)
+
+            # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
+            dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
+            dependent_logp = dependent_logp.transpose(*dims_alignment)
+
+        reduced_logps.append(dependent_logp)
+
+    reduced_logp = pt.add(*reduced_logps)
+    return reduced_logp
+
+
+def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
+    """Align the logp with the order specified in dims."""
+    dims_alignment = [dim for dim in dims if dim is not None]
+    return logp.transpose(*dims_alignment)
+
+
+def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
+    """Inline the inner graph (outputs) of an OpFromGraph Op.
+
+    Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
+    the inner graph.
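+
+    The returned outputs are expressed directly in terms of the provided `inputs`,
+    rather than the dummy inner inputs of the `OpFromGraph`.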
+ """ + return clone_replace( + op.inner_outputs, + replace=tuple(zip(op.inner_inputs, inputs)), + ) + + +DUMMY_ZERO = pt.constant(0, name="dummy_zero") + + +@_logprob.register(MarginalFiniteDiscreteRV) +def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs): + # Clone the inner RV graph of the Marginalized RV + marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs) + + # Obtain the joint_logp graph of the inner RV graph + inner_rv_values = dict(zip(inner_rvs, values)) + marginalized_vv = marginalized_rv.clone() + rv_values = inner_rv_values | {marginalized_rv: marginalized_vv} + logps_dict = conditional_logp(rv_values=rv_values, **kwargs) + + # Reduce logp dimensions corresponding to broadcasted variables + marginalized_logp = logps_dict.pop(marginalized_vv) + joint_logp = marginalized_logp + reduce_batch_dependent_logps( + dependent_dims_connections=op.dims_connections, + dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs], + dependent_logps=[logps_dict[value] for value in values], + ) + + # Compute the joint_logp for all possible n values of the marginalized RV. We assume + # each original dimension is independent so that it suffices to evaluate the graph + # n times, once with each possible value of the marginalized RV replicated across + # batched dimensions of the marginalized RV + + # PyMC does not allow RVs in the logp graph, even if we are just using the shape + marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False) + marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) + marginalized_rv_domain_tensor = pt.moveaxis( + pt.full( + (*marginalized_rv_shape, len(marginalized_rv_domain)), + marginalized_rv_domain, + dtype=marginalized_rv.dtype, + ), + -1, + 0, + ) + + try: + joint_logps = vectorize_graph( + joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor} + ) + except Exception: + # Fallback to Scan + def logp_fn(marginalized_rv_const, *non_sequences): + return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const}) + + joint_logps, _ = scan_map( + fn=logp_fn, + sequences=marginalized_rv_domain_tensor, + non_sequences=[*values, *inputs], + mode=Mode().including("local_remove_check_parameter"), + ) + + joint_logp = pt.logsumexp(joint_logps, axis=0) + + # Align logp with non-collapsed batch dimensions of first RV + joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp) + + # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise + dummy_logps = (DUMMY_ZERO,) * (len(values) - 1) + return joint_logp, *dummy_logps + + +@_logprob.register(MarginalDiscreteMarkovChainRV) +def marginal_hmm_logp(op, values, *inputs, **kwargs): + chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs) + + P, n_steps_, init_dist_, rng = chain_rv.owner.inputs + domain = pt.arange(P.shape[-1], dtype="int32") + + # Construct logp in two steps + # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission) + + # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating + # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise, + # PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step. 
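+    # The dependency is broken below by swapping the chain RV for a dummy value variable of the same type.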
+    chain_value = chain_rv.clone()
+    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
+    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))
+
+    # Reduce and add the batch dims beyond the chain dimension
+    reduced_logp_emissions = reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
+        dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
+        dependent_logps=[logp_emissions_dict[value] for value in values],
+    )
+
+    # Add a batch dimension for the domain of the chain
+    chain_shape = constant_fold(tuple(chain_rv.shape))
+    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
+    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})
+
+    # Step 2: Compute the transition probabilities
+    # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
+    # We do it entirely in logs, though.
+
+    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
+    # under the initial distribution. This is robust to everything the user can throw at it.
+    init_dist_value = init_dist_.type()
+    logp_init_dist = logp(init_dist_, init_dist_value)
+    # There is a degenerate batch dim for lags=1 (the only supported case),
+    # which we have to work around by expanding the batch value and then squeezing it out of the logp
+    batch_logp_init_dist = vectorize_graph(
+        logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
+    ).squeeze(1)
+    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]
+
+    def step_alpha(logp_emission, log_alpha, log_P):
+        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
+        return logp_emission + step_log_prob
+
+    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
+    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
+    log_alpha_seq, _ = scan(
+        step_alpha,
+        non_sequences=[log_P],
+        outputs_info=[log_alpha_init],
+        # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
+        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
+    )
+    # Final logp is just the sum of the last scan state
+    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
+
+    # Align logp with non-collapsed batch dimensions of first RV
+    remaining_dims_first_emission = list(op.dims_connections[0])
+    # The last dim of chain_rv was removed when computing the logp
+    remaining_dims_first_emission.remove(chain_rv.type.ndim - 1)
+    joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp)
+
+    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
+    # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
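+    # These dummy logps are constant zeros, so they do not change the total model logp.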
+ dummy_logps = (DUMMY_ZERO,) * (len(values) - 1) + return joint_logp, *dummy_logps diff --git a/pymc_experimental/model/marginal/graph_analysis.py b/pymc_experimental/model/marginal/graph_analysis.py new file mode 100644 index 00000000..62ac2abb --- /dev/null +++ b/pymc_experimental/model/marginal/graph_analysis.py @@ -0,0 +1,372 @@ +import itertools + +from collections.abc import Sequence +from itertools import zip_longest + +from pymc import SymbolicRandomVariable +from pytensor.compile import SharedVariable +from pytensor.graph import Constant, Variable, ancestors +from pytensor.graph.basic import io_toposort +from pytensor.tensor import TensorType, TensorVariable +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.rewriting.subtensor import is_full_slice +from pytensor.tensor.shape import Shape +from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list +from pytensor.tensor.type_other import NoneTypeT + +from pymc_experimental.model.marginal.distributions import MarginalRV + + +def static_shape_ancestors(vars): + """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph).""" + return [ + var + for var in ancestors(vars) + if ( + var.owner + and isinstance(var.owner.op, Shape) + # All static dims lengths of Shape input are known + and None not in var.owner.inputs[0].type.shape + ) + ] + + +def find_conditional_input_rvs(output_rvs, all_rvs): + """Find conditionally indepedent input RVs.""" + blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs] + blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs)) + return [ + var + for var in ancestors(output_rvs, blockers=blockers) + if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable)) + ] + + +def is_conditional_dependent( + dependent_rv: TensorVariable, dependable_rv: TensorVariable, all_rvs +) -> bool: + """Check if dependent_rv is conditionall dependent on dependable_rv, + given all conditionally independent all_rvs""" + + return dependable_rv in find_conditional_input_rvs((dependent_rv,), all_rvs) + + +def find_conditional_dependent_rvs(dependable_rv, all_rvs): + """Find rvs than depend on dependable""" + return [ + rv + for rv in all_rvs + if (rv is not dependable_rv and is_conditional_dependent(rv, dependable_rv, all_rvs)) + ] + + +def get_support_axes(op) -> tuple[tuple[int, ...], ...]: + if isinstance(op, MarginalRV): + return op.support_axes + else: + # For vanilla RVs, the support axes are the last ndim_supp + return (tuple(range(-op.ndim_supp, 0)),) + + +def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]: + """Find the output axis and dimensionality of the advanced indexing group (i.e., array indexing). + + There is a special case: when there are non-consecutive advanced indexing groups, the advanced indexing + group is always moved to the front. 
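+    For example (illustrative, following NumPy semantics): indexing a (2, 3, 4) tensor as
+    `x[[0, 1], :, [0, 1]]` uses non-consecutive advanced indices, so the advanced group
+    (ndim 1) is moved to axis 0 and the result has shape (2, 3).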
+ + See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + """ + adv_group_axis = None + simple_group_after_adv = False + for axis, idx in enumerate(idxs): + if isinstance(idx.type, TensorType): + if simple_group_after_adv: + # Special non-consecutive case + adv_group_axis = 0 + break + elif adv_group_axis is None: + adv_group_axis = axis + elif adv_group_axis is not None: + # Special non-consecutive case + simple_group_after_adv = True + + adv_group_ndim = max(idx.type.ndim for idx in idxs if isinstance(idx.type, TensorType)) + return adv_group_axis, adv_group_ndim + + +DIMS = tuple[int | None, ...] +VAR_DIMS = dict[Variable, DIMS] + + +def _broadcast_dims( + inputs_dims: Sequence[DIMS], +) -> DIMS: + output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0) + + # Add missing dims + inputs_dims = [ + (None,) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims + ] + + # Find which known dims show in the output, while checking no mixing + output_dims = [] + for inputs_dim in zip(*inputs_dims): + output_dim = None + for input_dim in inputs_dim: + if input_dim is None: + continue + if output_dim is not None and output_dim != input_dim: + raise ValueError("Different known dimensions mixed via broadcasting") + output_dim = input_dim + output_dims.append(output_dim) + + # Check for duplicates + known_dims = [dim for dim in output_dims if dim is not None] + if len(known_dims) > len(set(known_dims)): + raise ValueError("Same known dimension used in different axis after broadcasting") + + return tuple(output_dims) + + +def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR_DIMS: + for node in io_toposort(input_vars, output_vars): + inputs_dims = [ + var_dims.get(inp, ((None,) * inp.type.ndim) if hasattr(inp.type, "ndim") else ()) + for inp in node.inputs + ] + + if all(dim is None for input_dims in inputs_dims for dim in input_dims): + # None of the inputs are related to the batch_axes of the input_vars + continue + + elif isinstance(node.op, DimShuffle): + [input_dims] = inputs_dims + output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order) + var_dims[node.outputs[0]] = output_dims + + elif isinstance(node.op, MarginalRV) or ( + isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None + ): + # MarginalRV and SymbolicRandomVariables without signature are a wild-card, + # so we need to introspect the inner graph. 
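+            # This is done by mapping the known dims of the outer inputs onto the inner inputs,
+            # propagating them through the inner graph, and reading them back off the inner outputs.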
+ op = node.op + inner_inputs = op.inner_inputs + inner_outputs = op.inner_outputs + + inner_var_dims = _subgraph_batch_dim_connection( + dict(zip(inner_inputs, inputs_dims)), inner_inputs, inner_outputs + ) + + support_axes = iter(get_support_axes(op)) + if isinstance(op, MarginalRV): + # The first output is the marginalized variable for which we don't compute support axes + support_axes = itertools.chain(((),), support_axes) + for i, (out, inner_out) in enumerate(zip(node.outputs, inner_outputs)): + if not isinstance(out.type, TensorType): + continue + support_axes_out = next(support_axes) + + if inner_out in inner_var_dims: + out_dims = inner_var_dims[inner_out] + if any( + dim is not None for dim in (out_dims[axis] for axis in support_axes_out) + ): + raise ValueError(f"Known dim corresponds to core dimension of {node.op}") + var_dims[out] = out_dims + + elif isinstance(node.op, Elemwise | Blockwise | RandomVariable | SymbolicRandomVariable): + # NOTE: User-provided CustomDist may not respect core dimensions on the left. + + if isinstance(node.op, Elemwise): + op_batch_ndim = node.outputs[0].type.ndim + else: + op_batch_ndim = node.op.batch_ndim(node) + + if isinstance(node.op, SymbolicRandomVariable): + # SymbolicRandomVariable don't have explicit expand_dims unlike the other Ops considered in this + [_, _, param_idxs], _ = node.op.get_input_output_type_idxs( + node.op.extended_signature + ) + for param_idx, param_core_ndim in zip(param_idxs, node.op.ndims_params): + param_dims = inputs_dims[param_idx] + missing_ndim = op_batch_ndim - (len(param_dims) - param_core_ndim) + inputs_dims[param_idx] = (None,) * missing_ndim + param_dims + + if any( + dim is not None for input_dim in inputs_dims for dim in input_dim[op_batch_ndim:] + ): + raise ValueError( + f"Use of known dimensions as core dimensions of op {node.op} not supported." + ) + + batch_dims = _broadcast_dims( + tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims) + ) + for out in node.outputs: + if isinstance(out.type, TensorType): + core_ndim = out.type.ndim - op_batch_ndim + output_dims = batch_dims + (None,) * core_ndim + var_dims[out] = output_dims + + elif isinstance(node.op, CAReduce): + [input_dims] = inputs_dims + + axes = node.op.axis + if isinstance(axes, int): + axes = (axes,) + elif axes is None: + axes = tuple(range(node.inputs[0].type.ndim)) + + if any(input_dims[axis] for axis in axes): + raise ValueError( + f"Use of known dimensions as reduced dimensions of op {node.op} not supported." + ) + + output_dims = [dims for i, dims in enumerate(input_dims) if i not in axes] + var_dims[node.outputs[0]] = tuple(output_dims) + + elif isinstance(node.op, Subtensor): + value_dims, *keys_dims = inputs_dims + # Dims in basic indexing must belong to the value variable, since indexing keys are always scalar + assert not any(dim is None for dim in keys_dims) + keys = get_idx_list(node.inputs, node.op.idx_list) + + output_dims = [] + for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): + if idx == slice(None): + # Dim is kept + output_dims.append(value_dim) + elif value_dim is not None: + raise ValueError( + "Partial slicing or indexing of known dimensions not supported." + ) + elif isinstance(idx, slice): + # Unknown dimensions kept by partial slice. 
+ output_dims.append(None) + + var_dims[node.outputs[0]] = tuple(output_dims) + + elif isinstance(node.op, AdvancedSubtensor): + # AdvancedSubtensor dimensions can show up as both the indexed variable and indexing variables + value, *keys = node.inputs + value_dims, *keys_dims = inputs_dims + + # Just to stay sane, we forbid any boolean indexing... + if any(isinstance(idx.type, TensorType) and idx.type.dtype == "bool" for idx in keys): + raise NotImplementedError( + f"Array indexing with boolean variables in node {node} not supported." + ) + + if any(dim is not None for dim in value_dims) and any( + dim is not None for key_dims in keys_dims for dim in key_dims + ): + # Both indexed variable and indexing variables have known dimensions + # I am to lazy to think through these, so we raise for now. + raise NotImplementedError( + f"Simultaneous use of known dimensions in indexed and indexing variables in node {node} not supported." + ) + + adv_group_axis, adv_group_ndim = _advanced_indexing_axis_and_ndim(keys) + + if any(dim is not None for dim in value_dims): + # Indexed variable has known dimensions + + if any(isinstance(idx.type, NoneTypeT) for idx in keys): + # Corresponds to an expand_dims, for now not supported + raise NotImplementedError( + f"Advanced indexing in node {node} which introduces new axis is not supported." + ) + + non_adv_dims = [] + for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): + if is_full_slice(idx): + non_adv_dims.append(value_dim) + elif value_dim is not None: + # We are trying to partially slice or index a known dimension + raise ValueError( + "Partial slicing or advanced integer indexing of known dimensions not supported." + ) + elif isinstance(idx, slice): + # Unknown dimensions kept by partial slice. + non_adv_dims.append(None) + + # Insert unknown dimensions corresponding to advanced indexing + output_dims = tuple( + non_adv_dims[:adv_group_axis] + + [None] * adv_group_ndim + + non_adv_dims[adv_group_axis:] + ) + + else: + # Indexing keys have known dimensions. + # Only array indices can have dimensions, the rest are just slices or newaxis + + # Advanced indexing variables broadcast together, so we apply same rules as in Elemwise + adv_dims = _broadcast_dims(keys_dims) + + start_non_adv_dims = (None,) * adv_group_axis + end_non_adv_dims = (None,) * ( + node.outputs[0].type.ndim - adv_group_axis - adv_group_ndim + ) + output_dims = start_non_adv_dims + adv_dims + end_non_adv_dims + + var_dims[node.outputs[0]] = output_dims + + else: + raise NotImplementedError(f"Marginalization through operation {node} not supported.") + + return var_dims + + +def subgraph_batch_dim_connection(input_var, output_vars) -> list[DIMS]: + """Identify how the batch dims of input map to the batch dimensions of the output_rvs. + + Example: + ------- + In the example below `idx` has two batch dimensions (indexed 0, 1 from left to right). + The two uncommented dependent variables each have 2 batch dimensions where each entry + results from a mapping of a single entry from one of these batch dimensions. + + This mapping is transposed in the case of the first dependent variable, and shows up in + the same order for the second dependent variable. Each of the variables as a further + batch dimension encoded as `None`. + + The commented out third dependent variable combines information from the batch dimensions + of `idx` via the `sum` operation. A `ValueError` would be raised if we requested the + connection of batch dims. + + .. 
code-block:: python + import pymc as pm + + idx = pm.Bernoulli.dist(shape=(3, 2)) + dep1 = pm.Normal.dist(mu=idx.T[..., None] * 2, shape=(3, 2, 5)) + dep2 = pm.Normal.dist(mu=idx * 2, shape=(7, 2, 3)) + # dep3 = pm.Normal.dist(mu=idx.sum()) # Would raise if requested + + print(subgraph_batch_dim_connection(idx, [], [dep1, dep2])) + # [(1, 0, None), (None, 0, 1)] + + Returns: + ------- + list of tuples + Each tuple corresponds to the batch dimensions of the output_rv in the order they are found in the output. + None is used to indicate a batch dimension that is not mapped from the input. + + Raises: + ------ + ValueError + If input batch dimensions are mixed in the graph leading to output_vars. + + NotImplementedError + If variable related to marginalized batch_dims is used in an operation that is not yet supported + """ + var_dims = {input_var: tuple(range(input_var.type.ndim))} + var_dims = _subgraph_batch_dim_connection(var_dims, [input_var], output_vars) + ret = [] + for output_var in output_vars: + output_dims = var_dims.get(output_var, (None,) * output_var.type.ndim) + assert len(output_dims) == output_var.type.ndim + ret.append(output_dims) + return ret diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal/marginal_model.py similarity index 59% rename from pymc_experimental/model/marginal_model.py rename to pymc_experimental/model/marginal/marginal_model.py index c594e8ac..b4700c3d 100644 --- a/pymc_experimental/model/marginal_model.py +++ b/pymc_experimental/model/marginal/marginal_model.py @@ -8,30 +8,34 @@ import pytensor.tensor as pt from arviz import InferenceData, dict_to_dataset -from pymc import SymbolicRandomVariable from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform from pymc.distributions.transforms import Chain -from pymc.logprob.abstract import _logprob -from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.transforms import IntervalTransform from pymc.model import Model from pymc.pytensorf import compile_pymc, constant_fold from pymc.util import RandomState, _get_seeds_per_chain, treedict -from pytensor import Mode, scan from pytensor.compile import SharedVariable -from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace -from pytensor.graph.basic import graph_inputs -from pytensor.graph.replace import graph_replace, vectorize_graph -from pytensor.scan import map as scan_map -from pytensor.tensor import TensorType, TensorVariable -from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.shape import Shape +from pytensor.graph import FunctionGraph, clone_replace, graph_inputs +from pytensor.graph.replace import vectorize_graph +from pytensor.tensor import TensorVariable from pytensor.tensor.special import log_softmax __all__ = ["MarginalModel", "marginalize"] from pymc_experimental.distributions import DiscreteMarkovChain +from pymc_experimental.model.marginal.distributions import ( + MarginalDiscreteMarkovChainRV, + MarginalFiniteDiscreteRV, + get_domain_of_finite_discrete_rv, + reduce_batch_dependent_logps, +) +from pymc_experimental.model.marginal.graph_analysis import ( + find_conditional_dependent_rvs, + find_conditional_input_rvs, + is_conditional_dependent, + subgraph_batch_dim_connection, +) ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str] @@ -419,17 +423,22 @@ def transform_input(inputs): m = self.clone() marginalized_rv = 
m.vars_to_clone[marginalized_rv] m.unmarginalize([marginalized_rv]) - dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs) - joint_logps = m.logp(vars=[marginalized_rv, *dependent_vars], sum=False) - - marginalized_value = m.rvs_to_values[marginalized_rv] - other_values = [v for v in m.value_vars if v is not marginalized_value] + dependent_rvs = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs) + logps = m.logp(vars=[marginalized_rv, *dependent_rvs], sum=False) # Handle batch dims for marginalized value and its dependent RVs - marginalized_logp, *dependent_logps = joint_logps - joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps( - marginalized_rv.type, dependent_logps + dependent_rvs_dim_connections = subgraph_batch_dim_connection( + marginalized_rv, dependent_rvs ) + marginalized_logp, *dependent_logps = logps + joint_logp = marginalized_logp + reduce_batch_dependent_logps( + dependent_rvs_dim_connections, + [dependent_var.owner.op for dependent_var in dependent_rvs], + dependent_logps, + ) + + marginalized_value = m.rvs_to_values[marginalized_rv] + other_values = [v for v in m.value_vars if v is not marginalized_value] rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False) rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) @@ -443,37 +452,30 @@ def transform_input(inputs): 0, ) - joint_logps = vectorize_graph( + batched_joint_logp = vectorize_graph( joint_logp, replace={marginalized_value: rv_domain_tensor}, ) - joint_logps = pt.moveaxis(joint_logps, 0, -1) + batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1) - rv_loglike_fn = None - joint_logps_norm = log_softmax(joint_logps, axis=-1) + joint_logp_norm = log_softmax(batched_joint_logp, axis=-1) if return_samples: - sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps) + rv_draws = pymc.Categorical.dist(logit_p=batched_joint_logp) if isinstance(marginalized_rv.owner.op, DiscreteUniform): - sample_rv_outs += rv_domain[0] - - rv_loglike_fn = compile_pymc( - inputs=other_values, - outputs=[joint_logps_norm, sample_rv_outs], - on_unused_input="ignore", - random_seed=seed, - ) + rv_draws += rv_domain[0] + outputs = [joint_logp_norm, rv_draws] else: - rv_loglike_fn = compile_pymc( - inputs=other_values, - outputs=joint_logps_norm, - on_unused_input="ignore", - random_seed=seed, - ) + outputs = joint_logp_norm + + rv_loglike_fn = compile_pymc( + inputs=other_values, + outputs=outputs, + on_unused_input="ignore", + random_seed=seed, + ) logvs = [rv_loglike_fn(**vs) for vs in posterior_pts] - logps = None - samples = None if return_samples: logps, samples = zip(*logvs) logps = np.array(logps) @@ -540,105 +542,6 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel: return marginal_model -class MarginalRV(SymbolicRandomVariable): - """Base class for Marginalized RVs""" - - -class FiniteDiscreteMarginalRV(MarginalRV): - """Base class for Finite Discrete Marginalized RVs""" - - -class DiscreteMarginalMarkovChainRV(MarginalRV): - """Base class for Discrete Marginal Markov Chain RVs""" - - -def static_shape_ancestors(vars): - """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph).""" - return [ - var - for var in ancestors(vars) - if ( - var.owner - and isinstance(var.owner.op, Shape) - # All static dims lengths of Shape input are known - and None not in var.owner.inputs[0].type.shape - ) - ] - - -def find_conditional_input_rvs(output_rvs, all_rvs): - """Find conditionally indepedent input 
RVs.""" - blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs] - blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs)) - return [ - var - for var in ancestors(output_rvs, blockers=blockers) - if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable)) - ] - - -def is_conditional_dependent( - dependent_rv: TensorVariable, dependable_rv: TensorVariable, all_rvs -) -> bool: - """Check if dependent_rv is conditionall dependent on dependable_rv, - given all conditionally independent all_rvs""" - - return dependable_rv in find_conditional_input_rvs((dependent_rv,), all_rvs) - - -def find_conditional_dependent_rvs(dependable_rv, all_rvs): - """Find rvs than depend on dependable""" - return [ - rv - for rv in all_rvs - if (rv is not dependable_rv and is_conditional_dependent(rv, dependable_rv, all_rvs)) - ] - - -def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs): - # TODO: No need to consider apply nodes outside the subgraph... - fg = FunctionGraph(outputs=output_rvs, clone=False) - - non_elemwise_blockers = [ - o - for node in fg.apply_nodes - if not ( - isinstance(node.op, Elemwise) - # Allow expand_dims on the left - or ( - isinstance(node.op, DimShuffle) - and not node.op.drop - and node.op.shuffle == sorted(node.op.shuffle) - ) - ) - for o in node.outputs - ] - blocker_candidates = [rv_to_marginalize, *other_input_rvs, *non_elemwise_blockers] - blockers = [var for var in blocker_candidates if var not in output_rvs] - - truncated_inputs = [ - var - for var in ancestors(output_rvs, blockers=blockers) - if ( - var in blockers - or (var.owner is None and not isinstance(var, Constant | SharedVariable)) - ) - ] - - # Check that we reach the marginalized rv following a pure elemwise graph - if rv_to_marginalize not in truncated_inputs: - return False - - # Check that none of the truncated inputs depends on the marginalized_rv - other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize] - # TODO: We don't need to go all the way to the root variables - if rv_to_marginalize in ancestors( - other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs] - ): - return False - return True - - def collect_shared_vars(outputs, blockers): return [ inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable) @@ -646,225 +549,47 @@ def collect_shared_vars(outputs, blockers): def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs): - # TODO: This should eventually be integrated in a more general routine that can - # identify other types of supported marginalization, of which finite discrete - # RVs is just one - dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs) if not dependent_rvs: raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") - ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs} - if len(ndim_supp) != 1: - raise NotImplementedError( - "Marginalization with dependent variables of different support dimensionality not implemented" - ) - [ndim_supp] = ndim_supp - if ndim_supp > 0: - raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented") - marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) - dependent_rvs_input_rvs = [ + other_direct_rv_ancestors = [ rv for rv in find_conditional_input_rvs(dependent_rvs, all_rvs) if rv is not rv_to_marginalize ] - # If the marginalized RV has batched dimensions, 
check that graph between - # marginalized RV and dependent RVs is composed strictly of Elemwise Operations. - # This implies (?) that the dimensions are completely independent and a logp graph - # can ultimately be generated that is proportional to the support domain and not - # to the variables dimensions - # We don't need to worry about this if the RV is scalar. - if np.prod(constant_fold(tuple(rv_to_marginalize.shape), raise_not_constant=False)) != 1: - if not is_elemwise_subgraph(rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs): - raise NotImplementedError( - "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. " - "This is currently not supported", - ) + # If the marginalized RV has multiple dimensions, check that graph between + # marginalized RV and dependent RVs does not mix information from batch dimensions + # (otherwise logp would require enumerating over all combinations of batch dimension values) + try: + dependent_rvs_dim_connections = subgraph_batch_dim_connection( + rv_to_marginalize, dependent_rvs + ) + except (ValueError, NotImplementedError) as e: + # For the perspective of the user this is a NotImplementedError + raise NotImplementedError( + "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. " + "You can try splitting the marginalized RV into separate components and marginalizing them separately." + ) from e - input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs] - rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs] + input_rvs = list(set((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))) + output_rvs = [rv_to_marginalize, *dependent_rvs] - outputs = rvs_to_marginalize # We are strict about shared variables in SymbolicRandomVariables - inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs) + inputs = input_rvs + collect_shared_vars(output_rvs, blockers=input_rvs) if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain): - marginalize_constructor = DiscreteMarginalMarkovChainRV + marginalize_constructor = MarginalDiscreteMarkovChainRV else: - marginalize_constructor = FiniteDiscreteMarginalRV + marginalize_constructor = MarginalFiniteDiscreteRV marginalization_op = marginalize_constructor( inputs=inputs, - outputs=outputs, - ndim_supp=ndim_supp, - ) - - marginalized_rvs = marginalization_op(*inputs) - fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs))) - return rvs_to_marginalize, marginalized_rvs - - -def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]: - op = rv.owner.op - dist_params = rv.owner.op.dist_params(rv.owner) - if isinstance(op, Bernoulli): - return (0, 1) - elif isinstance(op, Categorical): - [p_param] = dist_params - return tuple(range(pt.get_vector_length(p_param))) - elif isinstance(op, DiscreteUniform): - lower, upper = constant_fold(dist_params) - return tuple(np.arange(lower, upper + 1)) - elif isinstance(op, DiscreteMarkovChain): - P, *_ = dist_params - return tuple(range(pt.get_vector_length(P[-1]))) - - raise NotImplementedError(f"Cannot compute domain for op {op}") - - -def _add_reduce_batch_dependent_logps( - marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable] -): - """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`.""" - - mbcast = marginalized_type.broadcastable - reduced_logps = [] - for dependent_logp in dependent_logps: - dbcast = dependent_logp.type.broadcastable - dim_diff = len(dbcast) - 
len(mbcast) - mbcast_aligned = (True,) * dim_diff + mbcast - vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v] - reduced_logps.append(dependent_logp.sum(vbcast_axis)) - return pt.add(*reduced_logps) - - -@_logprob.register(FiniteDiscreteMarginalRV) -def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): - # Clone the inner RV graph of the Marginalized RV - marginalized_rvs_node = op.make_node(*inputs) - marginalized_rv, *inner_rvs = clone_replace( - op.inner_outputs, - replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, - ) - - # Obtain the joint_logp graph of the inner RV graph - inner_rv_values = dict(zip(inner_rvs, values)) - marginalized_vv = marginalized_rv.clone() - rv_values = inner_rv_values | {marginalized_rv: marginalized_vv} - logps_dict = conditional_logp(rv_values=rv_values, **kwargs) - - # Reduce logp dimensions corresponding to broadcasted variables - marginalized_logp = logps_dict.pop(marginalized_vv) - joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps( - marginalized_rv.type, logps_dict.values() + outputs=output_rvs, # TODO: Add RNG updates to outputs so this can be used in the generative graph + dims_connections=dependent_rvs_dim_connections, ) - - # Compute the joint_logp for all possible n values of the marginalized RV. We assume - # each original dimension is independent so that it suffices to evaluate the graph - # n times, once with each possible value of the marginalized RV replicated across - # batched dimensions of the marginalized RV - - # PyMC does not allow RVs in the logp graph, even if we are just using the shape - marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False) - marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) - marginalized_rv_domain_tensor = pt.moveaxis( - pt.full( - (*marginalized_rv_shape, len(marginalized_rv_domain)), - marginalized_rv_domain, - dtype=marginalized_rv.dtype, - ), - -1, - 0, - ) - - try: - joint_logps = vectorize_graph( - joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor} - ) - except Exception: - # Fallback to Scan - def logp_fn(marginalized_rv_const, *non_sequences): - return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const}) - - joint_logps, _ = scan_map( - fn=logp_fn, - sequences=marginalized_rv_domain_tensor, - non_sequences=[*values, *inputs], - mode=Mode().including("local_remove_check_parameter"), - ) - - joint_logps = pt.logsumexp(joint_logps, axis=0) - - # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise - return joint_logps, *(pt.constant(0),) * (len(values) - 1) - - -@_logprob.register(DiscreteMarginalMarkovChainRV) -def marginal_hmm_logp(op, values, *inputs, **kwargs): - marginalized_rvs_node = op.make_node(*inputs) - inner_rvs = clone_replace( - op.inner_outputs, - replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, - ) - - chain_rv, *dependent_rvs = inner_rvs - P, n_steps_, init_dist_, rng = chain_rv.owner.inputs - domain = pt.arange(P.shape[-1], dtype="int32") - - # Construct logp in two steps - # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission) - - # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating - # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. 
Otherwise, - # PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step. - chain_value = chain_rv.clone() - dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value}) - logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values))) - - # Reduce and add the batch dims beyond the chain dimension - reduced_logp_emissions = _add_reduce_batch_dependent_logps( - chain_rv.type, logp_emissions_dict.values() - ) - - # Add a batch dimension for the domain of the chain - chain_shape = constant_fold(tuple(chain_rv.shape)) - batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0) - batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value}) - - # Step 2: Compute the transition probabilities - # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1}) - # We do it entirely in logs, though. - - # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) - # under the initial distribution. This is robust to everything the user can throw at it. - init_dist_value = init_dist_.type() - logp_init_dist = logp(init_dist_, init_dist_value) - # There is a degerate batch dim for lags=1 (the only supported case), - # that we have to work around, by expanding the batch value and then squeezing it out of the logp - batch_logp_init_dist = vectorize_graph( - logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]} - ).squeeze(1) - log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0] - - def step_alpha(logp_emission, log_alpha, log_P): - step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0) - return logp_emission + step_log_prob - - P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2) - log_P = pt.shape_padright(pt.log(P), P_bcast_dims) - log_alpha_seq, _ = scan( - step_alpha, - non_sequences=[log_P], - outputs_info=[log_alpha_init], - # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value - sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0), - ) - # Final logp is just the sum of the last scan state - joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0) - - # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first - # return is the joint probability of everything together, but PyMC still expects one logp for each one. 
- dummy_logps = (pt.constant(0),) * (len(values) - 1) - return joint_logp, *dummy_logps + new_output_rvs = marginalization_op(*inputs) + fgraph.replace_all(tuple(zip(output_rvs, new_output_rvs))) + return output_rvs, new_output_rvs diff --git a/requirements.txt b/requirements.txt index a7141a82..b992ad37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -pymc>=5.16.1 +pymc>=5.17.0 scikit-learn diff --git a/tests/model/marginal/__init__.py b/tests/model/marginal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/model/marginal/test_distributions.py b/tests/model/marginal/test_distributions.py new file mode 100644 index 00000000..ecbc8817 --- /dev/null +++ b/tests/model/marginal/test_distributions.py @@ -0,0 +1,131 @@ +import numpy as np +import pymc as pm +import pytest + +from pymc.logprob.abstract import _logprob +from pytensor import tensor as pt +from scipy.stats import norm + +from pymc_experimental import MarginalModel +from pymc_experimental.distributions import DiscreteMarkovChain +from pymc_experimental.model.marginal.distributions import MarginalFiniteDiscreteRV + + +def test_marginalized_bernoulli_logp(): + """Test logp of IR TestFiniteMarginalDiscreteRV directly""" + mu = pt.vector("mu") + + idx = pm.Bernoulli.dist(0.7, name="idx") + y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y") + marginal_rv_node = MarginalFiniteDiscreteRV( + [mu], + [idx, y], + dims_connections=(((),),), + )(mu)[0].owner + + y_vv = y.clone() + (logp,) = _logprob( + marginal_rv_node.op, + (y_vv,), + *marginal_rv_node.inputs, + ) + + ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=mu, sigma=1.0), y_vv) + np.testing.assert_almost_equal( + logp.eval({mu: [-1, 1], y_vv: 2}), + ref_logp.eval({mu: [-1, 1], y_vv: 2}), + ) + + +@pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}") +@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}") +def test_marginalized_hmm_normal_emission(batch_chain, batch_emission): + if batch_chain and not batch_emission: + pytest.skip("Redundant implicit combination") + + with MarginalModel() as m: + P = [[0, 1], [1, 0]] + init_dist = pm.Categorical.dist(p=[1, 0]) + chain = DiscreteMarkovChain( + "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None + ) + emission = pm.Normal( + "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None + ) + + m.marginalize([chain]) + logp_fn = m.compile_logp() + + test_value = np.array([-1, 1, -1, 1]) + expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval() + if batch_emission: + test_value = np.broadcast_to(test_value, (3, 4)) + expected_logp *= 3 + np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) + + +@pytest.mark.parametrize( + "categorical_emission", + [False, True], +) +def test_marginalized_hmm_categorical_emission(categorical_emission): + """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0""" + with MarginalModel() as m: + P = np.array([[0.5, 0.5], [0.3, 0.7]]) + init_dist = pm.Categorical.dist(p=[0.375, 0.625]) + chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2) + if categorical_emission: + emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain]) + else: + emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6)) + m.marginalize([chain]) + + test_value = np.array([0, 0, 1]) + expected_logp = np.log(0.1344) # Shown at the 
10m22s mark in the video + logp_fn = m.compile_logp() + np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) + + +@pytest.mark.parametrize("batch_chain", (False, True)) +@pytest.mark.parametrize("batch_emission1", (False, True)) +@pytest.mark.parametrize("batch_emission2", (False, True)) +def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch_emission2): + chain_shape = (3, 1, 4) if batch_chain else (4,) + emission1_shape = ( + (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape)) + ) + emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape + with MarginalModel() as m: + P = [[0, 1], [1, 0]] + init_dist = pm.Categorical.dist(p=[1, 0]) + chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape) + emission_1 = pm.Normal( + "emission_1", mu=(chain * 2 - 1).T, sigma=1e-1, shape=emission1_shape + ) + + emission2_mu = (1 - chain) * 2 - 1 + if batch_emission2: + emission2_mu = emission2_mu[..., None] + emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape) + + with pytest.warns(UserWarning, match="multiple dependent variables"): + m.marginalize([chain]) + + logp_fn = m.compile_logp(sum=False) + + test_value = np.array([-1, 1, -1, 1]) + multiplier = 2 + batch_emission1 + batch_emission2 + if batch_chain: + multiplier *= 3 + expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier + + test_value = np.broadcast_to(test_value, chain_shape) + test_value_emission1 = np.broadcast_to(test_value.T, emission1_shape) + if batch_emission2: + test_value_emission2 = np.broadcast_to(-test_value[..., None], emission2_shape) + else: + test_value_emission2 = np.broadcast_to(-test_value, emission2_shape) + test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2} + res_logp, dummy_logp = logp_fn(test_point) + assert res_logp.shape == ((1, 3) if batch_chain else ()) + np.testing.assert_allclose(res_logp.sum(), expected_logp) diff --git a/tests/model/marginal/test_graph_analysis.py b/tests/model/marginal/test_graph_analysis.py new file mode 100644 index 00000000..2382247b --- /dev/null +++ b/tests/model/marginal/test_graph_analysis.py @@ -0,0 +1,182 @@ +import pytensor.tensor as pt +import pytest + +from pymc.distributions import CustomDist +from pytensor.tensor.type_other import NoneTypeT + +from pymc_experimental.model.marginal.graph_analysis import ( + is_conditional_dependent, + subgraph_batch_dim_connection, +) + + +def test_is_conditional_dependent_static_shape(): + """Test that we don't consider dependencies through "constant" shape Ops""" + x1 = pt.matrix("x1", shape=(None, 5)) + y1 = pt.random.normal(size=pt.shape(x1)) + assert is_conditional_dependent(y1, x1, [x1, y1]) + + x2 = pt.matrix("x2", shape=(9, 5)) + y2 = pt.random.normal(size=pt.shape(x2)) + assert not is_conditional_dependent(y2, x2, [x2, y2]) + + +class TestSubgraphBatchDimConnection: + def test_dimshuffle(self): + inp = pt.tensor(shape=(5, 1, 4, 3)) + out1 = pt.matrix_transpose(inp) + out2 = pt.expand_dims(inp, 1) + out3 = pt.squeeze(inp) + [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3]) + assert dims1 == (0, 1, 3, 2) + assert dims2 == (0, None, 1, 2, 3) + assert dims3 == (0, 2, 3) + + def test_careduce(self): + inp = pt.tensor(shape=(4, 3, 2)) + + out = pt.sum(inp[:, None], axis=(1,)) + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, 2) + + invalid_out = pt.sum(inp, axis=(1,)) + with 
pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + def test_subtensor(self): + inp = pt.tensor(shape=(4, 3, 2)) + + invalid_out = inp[0, :1] + with pytest.raises( + ValueError, + match="Partial slicing or indexing of known dimensions not supported", + ): + subgraph_batch_dim_connection(inp, [invalid_out]) + + # If we are selecting dummy / unknown dimensions that's fine + valid_out = pt.expand_dims(inp, (0, 1))[0, :1] + [dims] = subgraph_batch_dim_connection(inp, [valid_out]) + assert dims == (None, 0, 1, 2) + + def test_advanced_subtensor_value(self): + inp = pt.tensor(shape=(2, 4)) + intermediate_out = inp[:, None, :, None] + pt.zeros((2, 3, 4, 5)) + + # Index on an unlabled dim introduced by broadcasting with zeros + out = intermediate_out[:, [0, 0, 1, 2]] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, None, 1, None) + + # Indexing that introduces more dimensions + out = intermediate_out[:, [[0, 0], [1, 2]], :] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, None, None, 1, None) + + # Special case where advanced dims are moved to the front of the output + out = intermediate_out[:, [0, 0, 1, 2], :, 0] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (None, 0, 1) + + # Indexing on a labeled dim fails + out = intermediate_out[:, :, [0, 0, 1, 2]] + with pytest.raises(ValueError, match="Partial slicing or advanced integer indexing"): + subgraph_batch_dim_connection(inp, [out]) + + def test_advanced_subtensor_key(self): + inp = pt.tensor(shape=(5, 5), dtype=int) + base = pt.zeros((2, 3, 4)) + + out = base[inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, None, None) + + out = base[:, :, inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == ( + None, + None, + 0, + 1, + ) + + out = base[1:, 0, inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (None, 0, 1) + + # Special case where advanced dims are moved to the front of the output + out = base[0, :, inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, None) + + # Mix keys dimensions + out = base[:, inp, inp.T] + with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"): + subgraph_batch_dim_connection(inp, [out]) + + def test_elemwise(self): + inp = pt.tensor(shape=(5, 5)) + + out = inp + inp + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1) + + out = inp + inp.T + with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"): + subgraph_batch_dim_connection(inp, [out]) + + out = inp[None, :, None, :] + inp[:, None, :, None] + with pytest.raises( + ValueError, match="Same known dimension used in different axis after broadcasting" + ): + subgraph_batch_dim_connection(inp, [out]) + + def test_blockwise(self): + inp = pt.tensor(shape=(5, 4)) + + invalid_out = inp @ pt.ones((4, 3)) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((2, 3)) + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, None, None) + + def test_random_variable(self): + inp = pt.tensor(shape=(5, 4, 3)) + + out1 = pt.random.normal(loc=inp) + out2 = pt.random.categorical(p=inp[..., None]) + out3 = pt.random.multivariate_normal(mean=inp[..., None], cov=pt.eye(1)) + [dims1, dims2, dims3] = 
subgraph_batch_dim_connection(inp, [out1, out2, out3]) + assert dims1 == (0, 1, 2) + assert dims2 == (0, 1, 2) + assert dims3 == (0, 1, 2, None) + + invalid_out = pt.random.categorical(p=inp) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + invalid_out = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3)) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + def test_symbolic_random_variable(self): + inp = pt.tensor(shape=(4, 3, 2)) + + # Test univariate + out = CustomDist.dist( + inp, + dist=lambda mu, size: pt.random.normal(loc=mu, size=size), + ) + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, 2) + + # Test multivariate + def dist(mu, size): + if isinstance(size.type, NoneTypeT): + size = mu.shape + return pt.random.normal(loc=mu[..., None], size=(*size, 2)) + + out = CustomDist.dist(inp, dist=dist, size=(4, 3, 2), signature="()->(2)") + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, 2, None) diff --git a/tests/model/marginal/test_marginal_model.py b/tests/model/marginal/test_marginal_model.py new file mode 100644 index 00000000..c93cdb74 --- /dev/null +++ b/tests/model/marginal/test_marginal_model.py @@ -0,0 +1,867 @@ +import itertools + +from contextlib import suppress as does_not_warn + +import numpy as np +import pandas as pd +import pymc as pm +import pytensor.tensor as pt +import pytest + +from arviz import InferenceData, dict_to_dataset +from pymc.distributions import transforms +from pymc.distributions.transforms import ordered +from pymc.model.fgraph import fgraph_from_model +from pymc.pytensorf import inputvars +from pymc.util import UNSET +from scipy.special import log_softmax, logsumexp +from scipy.stats import halfnorm, norm + +from pymc_experimental.model.marginal.marginal_model import ( + MarginalModel, + marginalize, +) +from tests.utils import equal_computations_up_to_root + + +def test_basic_marginalized_rv(): + data = [2] * 5 + + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") + idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6]) + mu = pt.switch( + pt.eq(idx, 0), + -1.0, + pt.switch( + pt.eq(idx, 1), + 0.0, + 1.0, + ), + ) + y = pm.Normal("y", mu=mu, sigma=sigma) + z = pm.Normal("z", y, observed=data) + + m.marginalize([idx]) + assert idx not in m.free_RVs + assert [rv.name for rv in m.marginalized_rvs] == ["idx"] + + # Test logp + with pm.Model() as m_ref: + sigma = pm.HalfNormal("sigma") + y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma) + z = pm.Normal("z", y, observed=data) + + test_point = m_ref.initial_point() + ref_logp = m_ref.compile_logp()(test_point) + ref_dlogp = m_ref.compile_dlogp([m_ref["y"]])(test_point) + + # Assert we can marginalize and unmarginalize internally non-destructively + for i in range(3): + np.testing.assert_almost_equal( + m.compile_logp()(test_point), + ref_logp, + ) + np.testing.assert_almost_equal( + m.compile_dlogp([m["y"]])(test_point), + ref_dlogp, + ) + + +def test_one_to_one_marginalized_rvs(): + """Test case with multiple, independent marginalized RVs.""" + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") + idx1 = pm.Bernoulli("idx1", p=0.75) + x = pm.Normal("x", mu=idx1, sigma=sigma) + idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,)) + y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,)) + + m.marginalize([idx1, idx2]) + m["x"].owner is not m["y"].owner + _m = 
m.clone()._marginalize() + _m["x"].owner is not _m["y"].owner + + with pm.Model() as m_ref: + sigma = pm.HalfNormal("sigma") + x = pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma) + y = pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,)) + + # Test logp + test_point = m_ref.initial_point() + x_logp, y_logp = m.compile_logp(vars=[m["x"], m["y"]], sum=False)(test_point) + x_ref_log, y_ref_logp = m_ref.compile_logp(vars=[m_ref["x"], m_ref["y"]], sum=False)(test_point) + np.testing.assert_array_almost_equal(x_logp, x_ref_log.sum()) + np.testing.assert_array_almost_equal(y_logp, y_ref_logp) + + +def test_one_to_many_marginalized_rvs(): + """Test that marginalization works when there is more than one dependent RV""" + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") + idx = pm.Bernoulli("idx", p=0.75) + x = pm.Normal("x", mu=idx, sigma=sigma) + y = pm.Normal("y", mu=(idx * 2 - 1), sigma=sigma, shape=(5,)) + + ref_logp_x_y_fn = m.compile_logp([idx, x, y]) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize([idx]) + + m["x"].owner is not m["y"].owner + _m = m.clone()._marginalize() + _m["x"].owner is _m["y"].owner + + tp = m.initial_point() + ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)]) + logp_x_y = m.compile_logp([x, y])(tp) + np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y) + + +def test_one_to_many_unaligned_marginalized_rvs(): + """Test that marginalization works when there is more than one dependent RV with batch dimensions that are not aligned""" + + def build_model(build_batched: bool): + with MarginalModel() as m: + if build_batched: + idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2)) + else: + idxs = [pm.Bernoulli(f"idx_{i}", p=(0.75 if i % 2 == 0 else 0.4)) for i in range(6)] + idx = pt.stack(idxs, axis=0).reshape((3, 2)) + + x = pm.Normal("x", mu=idx.T[:, :, None], shape=(2, 3, 1)) + y = pm.Normal("y", mu=(idx * 2 - 1), shape=(1, 3, 2)) + + return m + + m = build_model(build_batched=True) + ref_m = build_model(build_batched=False) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize(["idx"]) + ref_m.marginalize([f"idx_{i}" for i in range(6)]) + + test_point = m.initial_point() + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + + +def test_many_to_one_marginalized_rvs(): + """Test when random variables depend on multiple marginalized variables""" + with MarginalModel() as m: + x = pm.Bernoulli("x", 0.1) + y = pm.Bernoulli("y", 0.3) + z = pm.DiracDelta("z", c=x + y) + + m.marginalize([x, y]) + logp = m.compile_logp() + + np.testing.assert_allclose(np.exp(logp({"z": 0})), 0.9 * 0.7) + np.testing.assert_allclose(np.exp(logp({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7) + np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3) + + +@pytest.mark.parametrize("batched", (False, "left", "right")) +def test_nested_marginalized_rvs(batched): + """Test that marginalization works when there are nested marginalized RVs""" + + def build_model(build_batched: bool) -> MarginalModel: + idx_shape = (3,) if build_batched else () + sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5) + + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") + + idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape) + dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma) + + sub_idx_p = pt.switch(pt.eq(idx, 0), 0.15, 
0.95) + if build_batched and batched == "right": + sub_idx_p = sub_idx_p[..., None] + dep = dep[..., None] + sub_idx = pm.Bernoulli("sub_idx", p=sub_idx_p, shape=sub_idx_shape) + sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma) + + return m + + m = build_model(build_batched=batched) + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize(["idx", "sub_idx"]) + assert sorted(m.name for m in m.marginalized_rvs) == ["idx", "sub_idx"] + + # Test logp + ref_m = build_model(build_batched=False) + ref_logp_fn = ref_m.compile_logp( + vars=[ref_m["idx"], ref_m["dep"], ref_m["sub_idx"], ref_m["sub_dep"]] + ) + + test_point = ref_m.initial_point() + test_point["dep"] = np.full_like(test_point["dep"], 1000) + test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100) + ref_logp = logsumexp( + [ + ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}}) + for idx in (0, 1) + for sub_idxs in itertools.product((0, 1), repeat=5) + ] + ) + if batched: + ref_logp *= 3 + + test_point = m.initial_point() + test_point["dep"] = np.full_like(test_point["dep"], 1000) + test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100) + logp = m.compile_logp(vars=[m["dep"], m["sub_dep"]])(test_point) + + np.testing.assert_almost_equal(logp, ref_logp) + + +@pytest.mark.parametrize("advanced_indexing", (False, True)) +def test_marginalized_index_as_key(advanced_indexing): + """Test we can marginalize graphs where indexing is used as a mapping.""" + + w = [0.1, 0.3, 0.6] + mu = pt.as_tensor([-1, 0, 1]) + + if advanced_indexing: + y_val = pt.as_tensor([[-1, -1], [0, 1]]) + shape = (2, 2) + else: + y_val = -1 + shape = () + + with MarginalModel() as m: + x = pm.Categorical("x", p=w, shape=shape) + y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val) + + m.marginalize(x) + + marginal_logp = m.compile_logp(sum=False)({})[0] + ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval() + + np.testing.assert_allclose(marginal_logp, ref_logp) + + +def test_marginalized_index_as_value_and_key(): + """Test we can marginalize graphs where the marginalized_rv is indexed.""" + + def build_model(build_batched: bool) -> MarginalModel: + with MarginalModel() as m: + if build_batched: + latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,)) + else: + latent_state = pm.math.stack( + [pm.Bernoulli(f"latent_state_{i}", p=0.3) for i in range(4)] + ) + # latent state is used as the indexed variable + latent_intensities = pt.where(latent_state[:, None], [0.0, 1.0, 2.0], [0.0, 10.0, 20.0]) + picked_intensity = pm.Categorical("picked_intensity", p=[0.2, 0.2, 0.6]) + # picked intensity is used as the indexing variable + pm.Normal( + "intensity", + mu=latent_intensities[:, picked_intensity], + observed=[0.5, 1.5, 5.0, 15.0], + ) + return m + + # We compare against the equivalent but less efficient unbatched model + m = build_model(build_batched=True) + ref_m = build_model(build_batched=False) + + m.marginalize(["latent_state"]) + ref_m.marginalize([f"latent_state_{i}" for i in range(4)]) + test_point = {"picked_intensity": 1} + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + + m.marginalize(["picked_intensity"]) + ref_m.marginalize(["picked_intensity"]) + test_point = {} + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + + +class TestNotSupportedMixedDims: + """Test lack of support for models where batch dims of
marginalized variables are mixed.""" + + def test_mixed_dims_via_transposed_dot(self): + with MarginalModel() as m: + idx = pm.Bernoulli("idx", p=0.7, shape=2) + y = pm.Normal("y", mu=idx @ idx.T) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + def test_mixed_dims_via_indexing(self): + mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]]) + + with MarginalModel() as m: + idx = pm.Bernoulli("idx", p=0.7, shape=2) + y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx]) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + with MarginalModel() as m: + idx = pm.Bernoulli("idx", p=0.7, shape=2) + y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx]) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + with MarginalModel() as m: + idx = pm.Bernoulli("idx", p=0.7, shape=2) + mu = pt.specify_broadcastable(mean[:, None][idx], 1) + pt.specify_broadcastable( + mean[None, :][:, idx], 0 + ) + y = pm.Normal("y", mu=mu) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + with MarginalModel() as m: + idx = pm.Bernoulli("idx", p=0.7, shape=2) + y = pm.Normal("y", mu=idx[0] + idx[1]) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + def test_mixed_dims_via_vector_indexing(self): + with MarginalModel() as m: + idx = pm.Bernoulli("idx", p=0.7, shape=2) + y = pm.Normal("y", mu=idx[[0, 1, 0, 0]]) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + with MarginalModel() as m: + idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2)) + y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)]) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + def test_mixed_dims_via_support_dimension(self): + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7, shape=3) + y = pm.Dirichlet("y", a=x * 10 + 1) + with pytest.raises(NotImplementedError): + m.marginalize(x) + + def test_mixed_dims_via_nested_marginalization(self): + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7, shape=(3,)) + y = pm.Bernoulli("y", p=0.7, shape=(2,)) + z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2)) + + with pytest.raises(NotImplementedError): + m.marginalize([x, y]) + + +def test_marginalized_deterministic_and_potential(): + rng = np.random.default_rng(299) + + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7) + y = pm.Normal("y", x) + z = pm.Normal("z", x) + det = pm.Deterministic("det", y + z) + pot = pm.Potential("pot", y + z + 1) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize([x]) + + y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5, random_seed=rng) + np.testing.assert_almost_equal(y_draw + z_draw, det_draw) + np.testing.assert_almost_equal(det_draw, pot_draw - 1) + + y_value = m.rvs_to_values[y] + z_value = m.rvs_to_values[z] + det_value, pot_value = m.replace_rvs_by_values([det, pot]) + assert set(inputvars([det_value, pot_value])) == {y_value, z_value} + assert det_value.eval({y_value: 2, z_value: 5}) == 7 + assert pot_value.eval({y_value: 2, z_value: 5}) == 8 + + +def test_not_supported_marginalized_deterministic_and_potential(): + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7) + y = pm.Normal("y", x) + det = pm.Deterministic("det", x + y) + + with pytest.raises( + NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det" + ): + m.marginalize([x]) + + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7) + y = pm.Normal("y", x) + pot = pm.Potential("pot", x + y) + + with 
pytest.raises( + NotImplementedError, match="Cannot marginalize x due to dependent Potential pot" + ): + m.marginalize([x]) + + +@pytest.mark.parametrize( + "transform, expected_warning", + ( + (None, does_not_warn()), + (UNSET, does_not_warn()), + (transforms.log, does_not_warn()), + (transforms.Chain([transforms.log, transforms.logodds]), does_not_warn()), + ( + transforms.Interval(0, 1), + pytest.warns( + UserWarning, match="which depends on the marginalized idx may no longer work" + ), + ), + ( + transforms.Chain([transforms.log, transforms.Interval(0, 1)]), + pytest.warns( + UserWarning, match="which depends on the marginalized idx may no longer work" + ), + ), + ), +) +def test_marginalized_transforms(transform, expected_warning): + w = [0.1, 0.3, 0.6] + data = [0, 5, 10] + initval = 0.5 # Value that will be negative on the unconstrained space + + with pm.Model() as m_ref: + sigma = pm.Mixture( + "sigma", + w=w, + comp_dists=pm.HalfNormal.dist([1, 2, 3]), + initval=initval, + default_transform=transform, + ) + y = pm.Normal("y", 0, sigma, observed=data) + + with MarginalModel() as m: + idx = pm.Categorical("idx", p=w) + sigma = pm.HalfNormal( + "sigma", + pt.switch( + pt.eq(idx, 0), + 1, + pt.switch( + pt.eq(idx, 1), + 2, + 3, + ), + ), + initval=initval, + default_transform=transform, + ) + y = pm.Normal("y", 0, sigma, observed=data) + + with expected_warning: + m.marginalize([idx]) + + ip = m.initial_point() + if transform is not None: + if transform is UNSET: + transform_name = "log" + else: + transform_name = transform.name + assert f"sigma_{transform_name}__" in ip + np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip)) + + +def test_data_container(): + """Test that MarginalModel can handle Data containers.""" + with MarginalModel(coords={"obs": [0]}) as marginal_m: + x = pm.Data("x", 2.5) + idx = pm.Bernoulli("idx", p=0.7, dims="obs") + y = pm.Normal("y", idx * x, dims="obs") + + marginal_m.marginalize([idx]) + + logp_fn = marginal_m.compile_logp() + + with pm.Model(coords={"obs": [0]}) as m_ref: + x = pm.Data("x", 2.5) + y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[0, x], dims="obs") + + ref_logp_fn = m_ref.compile_logp() + + for i, x_val in enumerate((-1.5, 2.5, 3.5), start=1): + for m in (marginal_m, m_ref): + m.set_dim("obs", new_length=i, coord_values=tuple(range(i))) + pm.set_data({"x": x_val}, model=m) + + ip = marginal_m.initial_point() + np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip)) + + +def test_mutable_indexing_jax_backend(): + pytest.importorskip("jax") + from pymc.sampling.jax import get_jaxified_logp + + with MarginalModel() as model: + data = pm.Data("data", np.zeros(10)) + + cat_effect = pm.Normal("cat_effect", sigma=1, shape=5) + cat_effect_idx = pm.Data("cat_effect_idx", np.array([0, 1] * 5)) + + is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10) + pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data) + model.marginalize(["is_outlier"]) + get_jaxified_logp(model) + + +def test_marginal_model_func(): + def create_model(model_class): + with model_class(coords={"trial": range(10)}) as m: + idx = pm.Bernoulli("idx", p=0.5, dims="trial") + mu = pt.where(idx, 1, -1) + sigma = pm.HalfNormal("sigma") + y = pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10) + return m + + marginal_m = marginalize(create_model(pm.Model), ["idx"]) + assert isinstance(marginal_m, MarginalModel) + + reference_m = create_model(MarginalModel) + reference_m.marginalize(["idx"]) + + # Check forward graph 
representation is the same + marginal_fgraph, _ = fgraph_from_model(marginal_m) + reference_fgraph, _ = fgraph_from_model(reference_m) + assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs) + + # Check logp graph is the same + # This fails because OpFromGraphs comparison is broken + # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()]) + ip = marginal_m.initial_point() + np.testing.assert_allclose( + marginal_m.compile_logp()(ip), + reference_m.compile_logp()(ip), + ) + + +class TestFullModels: + @pytest.fixture + def disaster_model(self): + # fmt: off + disaster_data = pd.Series( + [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6, + 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5, + 2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0, + 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1, + 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2, + 3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4, + 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1] + ) + # fmt: on + years = np.arange(1851, 1962) + + with MarginalModel() as disaster_model: + switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max()) + early_rate = pm.Exponential("early_rate", 1.0, initval=3) + late_rate = pm.Exponential("late_rate", 1.0, initval=1) + rate = pm.math.switch(switchpoint >= years, early_rate, late_rate) + with pytest.warns(Warning): + disasters = pm.Poisson("disasters", rate, observed=disaster_data) + + return disaster_model, years + + def test_change_point_model(self, disaster_model): + m, years = disaster_model + + ip = m.initial_point() + ip.pop("switchpoint") + ref_logp_fn = m.compile_logp( + [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]] + ) + ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years]) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize(m["switchpoint"]) + + logp = m.compile_logp([m["disasters_observed"], m["disasters_unobserved"]])(ip) + np.testing.assert_almost_equal(logp, ref_logp) + + @pytest.mark.slow + def test_change_point_model_sampling(self, disaster_model): + m, _ = disaster_model + + rng = np.random.default_rng(211) + + with m: + before_marg = pm.sample(chains=2, random_seed=rng).posterior.stack( + sample=("draw", "chain") + ) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize([m["switchpoint"]]) + + with m: + after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack( + sample=("draw", "chain") + ) + + np.testing.assert_allclose( + before_marg["early_rate"].mean(), after_marg["early_rate"].mean(), rtol=1e-2 + ) + np.testing.assert_allclose( + before_marg["late_rate"].mean(), after_marg["late_rate"].mean(), rtol=1e-2 + ) + np.testing.assert_allclose( + before_marg["disasters_unobserved"].mean(), + after_marg["disasters_unobserved"].mean(), + rtol=1e-2, + ) + + @pytest.mark.parametrize("univariate", (True, False)) + def test_vector_univariate_mixture(self, univariate): + with MarginalModel() as m: + idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ()) + + def dist(idx, size): + return pm.math.switch( + pm.math.eq(idx, 0), + pm.Normal.dist([-10, -10], 1), + pm.Normal.dist([10, 10], 1), + ) + + pm.CustomDist("norm", idx, dist=dist) + + m.marginalize(idx) + logp_fn = m.compile_logp() + + if univariate: + with pm.Model() as ref_m: + pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,)) + else: + with pm.Model() as ref_m: + 
pm.Mixture( + "norm", + w=[0.5, 0.5], + comp_dists=[ + pm.MvNormal.dist([-10, -10], np.eye(2)), + pm.MvNormal.dist([10, 10], np.eye(2)), + ], + shape=(2,), + ) + ref_logp_fn = ref_m.compile_logp() + + for test_value in ( + [-10, -10], + [10, 10], + [-10, 10], + [-10, 10], + ): + pt = {"norm": test_value} + np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt)) + + def test_k_censored_clusters_model(self): + def build_model(build_batched: bool) -> MarginalModel: + data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]) + nobs = data.shape[0] + n_clusters = 5 + coords = { + "cluster": range(n_clusters), + "ndim": ("x", "y"), + "obs": range(nobs), + } + with MarginalModel(coords=coords) as m: + if build_batched: + idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"]) + else: + idx = pm.math.stack( + [ + pm.Categorical(f"idx_{i}", p=np.ones(n_clusters) / n_clusters) + for i in range(nobs) + ] + ) + + mu_x = pm.Normal( + "mu_x", + dims=["cluster"], + transform=ordered, + initval=np.linspace(-1, 1, n_clusters), + ) + mu_y = pm.Normal("mu_y", dims=["cluster"]) + mu = pm.math.stack([mu_x, mu_y], axis=-1) # (cluster, ndim) + mu_indexed = mu[idx, :] + + sigma = pm.HalfNormal("sigma") + + y = pm.Censored( + "y", + dist=pm.Normal.dist(mu_indexed, sigma), + lower=-3, + upper=3, + observed=data, + dims=["obs", "ndim"], + ) + + return m + + m = build_model(build_batched=True) + ref_m = build_model(build_batched=False) + + m.marginalize([m["idx"]]) + ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")]) + + test_point = m.initial_point() + np.testing.assert_almost_equal( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + + +class TestRecoverMarginals: + def test_basic(self): + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") + p = np.array([0.5, 0.2, 0.3]) + k = pm.Categorical("k", p=p) + mu = np.array([-3.0, 0.0, 3.0]) + mu_ = pt.as_tensor_variable(mu) + y = pm.Normal("y", mu=mu_[k], sigma=sigma) + + m.marginalize([k]) + + rng = np.random.default_rng(211) + + with m: + prior = pm.sample_prior_predictive( + draws=20, + random_seed=rng, + return_inferencedata=False, + ) + idata = InferenceData(posterior=dict_to_dataset(prior)) + + idata = m.recover_marginals(idata, return_samples=True) + post = idata.posterior + assert "k" in post + assert "lp_k" in post + assert post.k.shape == post.y.shape + assert post.lp_k.shape == (*post.k.shape, len(p)) + + def true_logp(y, sigma): + y = y.repeat(len(p)).reshape(len(y), -1) + sigma = sigma.repeat(len(p)).reshape(len(sigma), -1) + return log_softmax( + np.log(p) + + norm.logpdf(y, loc=mu, scale=sigma) + + halfnorm.logpdf(sigma) + + np.log(sigma), + axis=1, + ) + + np.testing.assert_almost_equal( + true_logp(post.y.values.flatten(), post.sigma.values.flatten()), + post.lp_k[0].values, + ) + np.testing.assert_almost_equal(logsumexp(post.lp_k, axis=-1), 0) + + def test_coords(self): + """Test if coords can be recovered with marginalized value had it originally""" + with MarginalModel(coords={"year": [1990, 1991, 1992]}) as m: + sigma = pm.HalfNormal("sigma") + idx = pm.Bernoulli("idx", p=0.75, dims="year") + x = pm.Normal("x", mu=idx, sigma=sigma, dims="year") + + m.marginalize([idx]) + rng = np.random.default_rng(211) + + with m: + prior = pm.sample_prior_predictive( + draws=20, + random_seed=rng, + return_inferencedata=False, + ) + idata = InferenceData( + posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior}) + ) + + idata = m.recover_marginals(idata, 
return_samples=True) + post = idata.posterior + assert post.idx.dims == ("chain", "draw", "year") + assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim") + + def test_batched(self): + """Test that marginalization works for batched random variables""" + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") + idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2)) + y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3)) + + m.marginalize([idx]) + + rng = np.random.default_rng(211) + + with m: + prior = pm.sample_prior_predictive( + draws=20, + random_seed=rng, + return_inferencedata=False, + ) + idata = InferenceData( + posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior}) + ) + + idata = m.recover_marginals(idata, return_samples=True) + post = idata.posterior + assert post["y"].shape == (1, 20, 2, 3) + assert post["idx"].shape == (1, 20, 3, 2) + assert post["lp_idx"].shape == (1, 20, 3, 2, 2) + + def test_nested(self): + """Test that marginalization works when there are nested marginalized RVs""" + + with MarginalModel() as m: + idx = pm.Bernoulli("idx", p=0.75) + sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95)) + sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0) + + m.marginalize([idx, sub_idx]) + + rng = np.random.default_rng(211) + + with m: + prior = pm.sample_prior_predictive( + draws=20, + random_seed=rng, + return_inferencedata=False, + ) + idata = InferenceData(posterior=dict_to_dataset(prior)) + + idata = m.recover_marginals(idata, return_samples=True) + post = idata.posterior + assert "idx" in post + assert "lp_idx" in post + assert post.idx.shape == post.y.shape + assert post.lp_idx.shape == (*post.idx.shape, 2) + assert "sub_idx" in post + assert "lp_sub_idx" in post + assert post.sub_idx.shape == post.y.shape + assert post.lp_sub_idx.shape == (*post.sub_idx.shape, 2) + + def true_idx_logp(y): + idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1)) + idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2)) + return log_softmax(np.stack([idx_0, idx_1]).T, axis=1) + + np.testing.assert_almost_equal( + true_idx_logp(post.y.values.flatten()), + post.lp_idx[0].values, + ) + + def true_sub_idx_logp(y): + sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1)) + sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2)) + return log_softmax(np.stack([sub_idx_0, sub_idx_1]).T, axis=1) + + np.testing.assert_almost_equal( + true_sub_idx_logp(post.y.values.flatten()), + post.lp_sub_idx[0].values, + ) + np.testing.assert_almost_equal(logsumexp(post.lp_idx, axis=-1), 0) + np.testing.assert_almost_equal(logsumexp(post.lp_sub_idx, axis=-1), 0) diff --git a/tests/model/test_marginal_model.py b/tests/model/test_marginal_model.py deleted file mode 100644 index 7f97b15b..00000000 --- a/tests/model/test_marginal_model.py +++ /dev/null @@ -1,806 +0,0 @@ -import itertools - -from contextlib import suppress as does_not_warn - -import numpy as np -import pandas as pd -import pymc as pm -import pytensor.tensor as pt -import pytest - -from arviz import InferenceData, dict_to_dataset -from pymc.distributions import transforms -from pymc.logprob.abstract import _logprob -from pymc.model.fgraph import fgraph_from_model -from pymc.pytensorf import inputvars -from pymc.util import UNSET -from scipy.special import log_softmax, logsumexp -from scipy.stats import halfnorm, norm - -from pymc_experimental.distributions import 
DiscreteMarkovChain -from pymc_experimental.model.marginal_model import ( - FiniteDiscreteMarginalRV, - MarginalModel, - is_conditional_dependent, - marginalize, -) -from tests.utils import equal_computations_up_to_root - - -@pytest.fixture -def disaster_model(): - # fmt: off - disaster_data = pd.Series( - [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6, - 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5, - 2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0, - 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1, - 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2, - 3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4, - 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1] - ) - # fmt: on - years = np.arange(1851, 1962) - - with MarginalModel() as disaster_model: - switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max()) - early_rate = pm.Exponential("early_rate", 1.0, initval=3) - late_rate = pm.Exponential("late_rate", 1.0, initval=1) - rate = pm.math.switch(switchpoint >= years, early_rate, late_rate) - with pytest.warns(Warning): - disasters = pm.Poisson("disasters", rate, observed=disaster_data) - - return disaster_model, years - - -@pytest.mark.filterwarnings("error") -def test_marginalized_bernoulli_logp(): - """Test logp of IR TestFiniteMarginalDiscreteRV directly""" - mu = pt.vector("mu") - - idx = pm.Bernoulli.dist(0.7, name="idx") - y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y") - marginal_rv_node = FiniteDiscreteMarginalRV( - [mu], - [idx, y], - ndim_supp=0, - n_updates=0, - # Ignore the fact we didn't specify shared RNG input/outputs for idx,y - strict=False, - )(mu)[0].owner - - y_vv = y.clone() - (logp,) = _logprob( - marginal_rv_node.op, - (y_vv,), - *marginal_rv_node.inputs, - ) - - ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=mu, sigma=1.0), y_vv) - np.testing.assert_almost_equal( - logp.eval({mu: [-1, 1], y_vv: 2}), - ref_logp.eval({mu: [-1, 1], y_vv: 2}), - ) - - -@pytest.mark.filterwarnings("error") -def test_marginalized_basic(): - data = [2] * 5 - - with MarginalModel() as m: - sigma = pm.HalfNormal("sigma") - idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6]) - mu = pt.switch( - pt.eq(idx, 0), - -1.0, - pt.switch( - pt.eq(idx, 1), - 0.0, - 1.0, - ), - ) - y = pm.Normal("y", mu=mu, sigma=sigma) - z = pm.Normal("z", y, observed=data) - - m.marginalize([idx]) - assert idx not in m.free_RVs - assert [rv.name for rv in m.marginalized_rvs] == ["idx"] - - # Test logp - with pm.Model() as m_ref: - sigma = pm.HalfNormal("sigma") - y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma) - z = pm.Normal("z", y, observed=data) - - test_point = m_ref.initial_point() - ref_logp = m_ref.compile_logp()(test_point) - ref_dlogp = m_ref.compile_dlogp([m_ref["y"]])(test_point) - - # Assert we can marginalize and unmarginalize internally non-destructively - for i in range(3): - np.testing.assert_almost_equal( - m.compile_logp()(test_point), - ref_logp, - ) - np.testing.assert_almost_equal( - m.compile_dlogp([m["y"]])(test_point), - ref_dlogp, - ) - - -@pytest.mark.filterwarnings("error") -def test_multiple_independent_marginalized_rvs(): - with MarginalModel() as m: - sigma = pm.HalfNormal("sigma") - idx1 = pm.Bernoulli("idx1", p=0.75) - x = pm.Normal("x", mu=idx1, sigma=sigma) - idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,)) - y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,)) - - m.marginalize([idx1, idx2]) - m["x"].owner is not m["y"].owner - _m = m.clone()._marginalize() - _m["x"].owner is not _m["y"].owner - - with pm.Model() 
as m_ref: - sigma = pm.HalfNormal("sigma") - x = pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma) - y = pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,)) - - # Test logp - test_point = m_ref.initial_point() - x_logp, y_logp = m.compile_logp(vars=[m["x"], m["y"]], sum=False)(test_point) - x_ref_log, y_ref_logp = m_ref.compile_logp(vars=[m_ref["x"], m_ref["y"]], sum=False)(test_point) - np.testing.assert_array_almost_equal(x_logp, x_ref_log.sum()) - np.testing.assert_array_almost_equal(y_logp, y_ref_logp) - - -@pytest.mark.filterwarnings("error") -def test_multiple_dependent_marginalized_rvs(): - """Test that marginalization works when there is more than one dependent RV""" - with MarginalModel() as m: - sigma = pm.HalfNormal("sigma") - idx = pm.Bernoulli("idx", p=0.75) - x = pm.Normal("x", mu=idx, sigma=sigma) - y = pm.Normal("y", mu=(idx * 2 - 1), sigma=sigma, shape=(5,)) - - ref_logp_x_y_fn = m.compile_logp([idx, x, y]) - - with pytest.warns(UserWarning, match="There are multiple dependent variables"): - m.marginalize([idx]) - - m["x"].owner is not m["y"].owner - _m = m.clone()._marginalize() - _m["x"].owner is _m["y"].owner - - tp = m.initial_point() - ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)]) - logp_x_y = m.compile_logp([x, y])(tp) - np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y) - - -def test_rv_dependent_multiple_marginalized_rvs(): - """Test when random variables depend on multiple marginalized variables""" - with MarginalModel() as m: - x = pm.Bernoulli("x", 0.1) - y = pm.Bernoulli("y", 0.3) - z = pm.DiracDelta("z", c=x + y) - - m.marginalize([x, y]) - logp = m.compile_logp() - - np.testing.assert_allclose(np.exp(logp({"z": 0})), 0.9 * 0.7) - np.testing.assert_allclose(np.exp(logp({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7) - np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3) - - -@pytest.mark.filterwarnings("error") -def test_nested_marginalized_rvs(): - """Test that marginalization works when there are nested marginalized RVs""" - - with MarginalModel() as m: - sigma = pm.HalfNormal("sigma") - - idx = pm.Bernoulli("idx", p=0.75) - dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma) - - sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95), shape=(5,)) - sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma, shape=(5,)) - - ref_logp_fn = m.compile_logp(vars=[idx, dep, sub_idx, sub_dep]) - - with pytest.warns(UserWarning, match="There are multiple dependent variables"): - m.marginalize([idx, sub_idx]) - - assert set(m.marginalized_rvs) == {idx, sub_idx} - - # Test logp - test_point = m.initial_point() - test_point["dep"] = 1000 - test_point["sub_dep"] = np.full((5,), 1000 + 100) - - ref_logp = [ - ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}}) - for idx in (0, 1) - for sub_idxs in itertools.product((0, 1), repeat=5) - ] - logp = m.compile_logp(vars=[dep, sub_dep])(test_point) - - np.testing.assert_almost_equal( - logp, - logsumexp(ref_logp), - ) - - -@pytest.mark.filterwarnings("error") -def test_marginalized_change_point_model(disaster_model): - m, years = disaster_model - - ip = m.initial_point() - ip.pop("switchpoint") - ref_logp_fn = m.compile_logp( - [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]] - ) - ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years]) - - with pytest.warns(UserWarning, match="There are multiple dependent variables"): 
- m.marginalize(m["switchpoint"]) - - logp = m.compile_logp([m["disasters_observed"], m["disasters_unobserved"]])(ip) - np.testing.assert_almost_equal(logp, ref_logp) - - -@pytest.mark.slow -@pytest.mark.filterwarnings("error") -def test_marginalized_change_point_model_sampling(disaster_model): - m, _ = disaster_model - - rng = np.random.default_rng(211) - - with m: - before_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(sample=("draw", "chain")) - - with pytest.warns(UserWarning, match="There are multiple dependent variables"): - m.marginalize([m["switchpoint"]]) - - with m: - after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(sample=("draw", "chain")) - - np.testing.assert_allclose( - before_marg["early_rate"].mean(), after_marg["early_rate"].mean(), rtol=1e-2 - ) - np.testing.assert_allclose( - before_marg["late_rate"].mean(), after_marg["late_rate"].mean(), rtol=1e-2 - ) - np.testing.assert_allclose( - before_marg["disasters_unobserved"].mean(), - after_marg["disasters_unobserved"].mean(), - rtol=1e-2, - ) - - -def test_recover_marginals_basic(): - with MarginalModel() as m: - sigma = pm.HalfNormal("sigma") - p = np.array([0.5, 0.2, 0.3]) - k = pm.Categorical("k", p=p) - mu = np.array([-3.0, 0.0, 3.0]) - mu_ = pt.as_tensor_variable(mu) - y = pm.Normal("y", mu=mu_[k], sigma=sigma) - - m.marginalize([k]) - - rng = np.random.default_rng(211) - - with m: - prior = pm.sample_prior_predictive( - draws=20, - random_seed=rng, - return_inferencedata=False, - ) - idata = InferenceData(posterior=dict_to_dataset(prior)) - - idata = m.recover_marginals(idata, return_samples=True) - post = idata.posterior - assert "k" in post - assert "lp_k" in post - assert post.k.shape == post.y.shape - assert post.lp_k.shape == (*post.k.shape, len(p)) - - def true_logp(y, sigma): - y = y.repeat(len(p)).reshape(len(y), -1) - sigma = sigma.repeat(len(p)).reshape(len(sigma), -1) - return log_softmax( - np.log(p) - + norm.logpdf(y, loc=mu, scale=sigma) - + halfnorm.logpdf(sigma) - + np.log(sigma), - axis=1, - ) - - np.testing.assert_almost_equal( - true_logp(post.y.values.flatten(), post.sigma.values.flatten()), - post.lp_k[0].values, - ) - np.testing.assert_almost_equal(logsumexp(post.lp_k, axis=-1), 0) - - -def test_recover_marginals_coords(): - """Test if coords can be recovered with marginalized value had it originally""" - with MarginalModel(coords={"year": [1990, 1991, 1992]}) as m: - sigma = pm.HalfNormal("sigma") - idx = pm.Bernoulli("idx", p=0.75, dims="year") - x = pm.Normal("x", mu=idx, sigma=sigma, dims="year") - - m.marginalize([idx]) - rng = np.random.default_rng(211) - - with m: - prior = pm.sample_prior_predictive( - draws=20, - random_seed=rng, - return_inferencedata=False, - ) - idata = InferenceData( - posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior}) - ) - - idata = m.recover_marginals(idata, return_samples=True) - post = idata.posterior - assert post.idx.dims == ("chain", "draw", "year") - assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim") - - -def test_recover_batched_marginal(): - """Test that marginalization works for batched random variables""" - with MarginalModel() as m: - sigma = pm.HalfNormal("sigma") - idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2)) - y = pm.Normal("y", mu=idx, sigma=sigma, shape=(3, 2)) - - m.marginalize([idx]) - - rng = np.random.default_rng(211) - - with m: - prior = pm.sample_prior_predictive( - draws=20, - random_seed=rng, - return_inferencedata=False, - ) - idata = InferenceData( - 
posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior}) - ) - - idata = m.recover_marginals(idata, return_samples=True) - post = idata.posterior - assert "idx" in post - assert "lp_idx" in post - assert post.idx.shape == post.y.shape - assert post.lp_idx.shape == (*post.idx.shape, 2) - - -@pytest.mark.xfail(reason="Still need to investigate") -def test_nested_recover_marginals(): - """Test that marginalization works when there are nested marginalized RVs""" - - with MarginalModel() as m: - idx = pm.Bernoulli("idx", p=0.75) - sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95)) - sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0) - - m.marginalize([idx, sub_idx]) - - rng = np.random.default_rng(211) - - with m: - prior = pm.sample_prior_predictive( - draws=20, - random_seed=rng, - return_inferencedata=False, - ) - idata = InferenceData(posterior=dict_to_dataset(prior)) - - idata = m.recover_marginals(idata, return_samples=True) - post = idata.posterior - assert "idx" in post - assert "lp_idx" in post - assert post.idx.shape == post.y.shape - assert post.lp_idx.shape == (*post.idx.shape, 2) - assert "sub_idx" in post - assert "lp_sub_idx" in post - assert post.sub_idx.shape == post.y.shape - assert post.lp_sub_idx.shape == (*post.sub_idx.shape, 2) - - def true_idx_logp(y): - idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1)) - idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2)) - return log_softmax(np.stack([idx_0, idx_1]).T, axis=1) - - np.testing.assert_almost_equal( - true_idx_logp(post.y.values.flatten()), - post.lp_idx[0].values, - ) - - def true_sub_idx_logp(y): - sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1)) - sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2)) - return log_softmax(np.stack([sub_idx_0, sub_idx_1]).T, axis=1) - - np.testing.assert_almost_equal( - true_sub_idx_logp(post.y.values.flatten()), - post.lp_sub_idx[0].values, - ) - np.testing.assert_almost_equal(logsumexp(post.lp_idx, axis=-1), 0) - np.testing.assert_almost_equal(logsumexp(post.lp_sub_idx, axis=-1), 0) - - -@pytest.mark.filterwarnings("error") -def test_not_supported_marginalized(): - """Marginalized graphs with non-Elemwise Operations are not supported as they - would violate the batching logp assumption""" - mu = pt.constant([-1, 1]) - - # Allowed, as only elemwise operations connect idx to y - with MarginalModel() as m: - p = pm.Beta("p", 1, 1) - idx = pm.Bernoulli("idx", p=p, size=2) - y = pm.Normal("y", mu=pm.math.switch(idx, 0, 1)) - m.marginalize([idx]) - - # ALlowed, as index operation does not connext idx to y - with MarginalModel() as m: - p = pm.Beta("p", 1, 1) - idx = pm.Bernoulli("idx", p=p, size=2) - y = pm.Normal("y", mu=pm.math.switch(idx, mu[0], mu[1])) - m.marginalize([idx]) - - # Not allowed, as index operation connects idx to y - with MarginalModel() as m: - p = pm.Beta("p", 1, 1) - idx = pm.Bernoulli("idx", p=p, size=2) - # Not allowed - y = pm.Normal("y", mu=mu[idx]) - with pytest.raises(NotImplementedError): - m.marginalize(idx) - - # Not allowed, as index operation connects idx to y, even though there is a - # pure Elemwise connection between the two - with MarginalModel() as m: - p = pm.Beta("p", 1, 1) - idx = pm.Bernoulli("idx", p=p, size=2) - y = pm.Normal("y", mu=mu[idx] + idx) - with pytest.raises(NotImplementedError): - m.marginalize(idx) - - # Multivariate dependent RVs not supported - with 
MarginalModel() as m: - x = pm.Bernoulli("x", p=0.7) - y = pm.Dirichlet("y", a=pm.math.switch(x, [1, 1, 1], [10, 10, 10])) - with pytest.raises( - NotImplementedError, - match="Marginalization with dependent Multivariate RVs not implemented", - ): - m.marginalize(x) - - -@pytest.mark.filterwarnings("error") -def test_marginalized_deterministic_and_potential(): - rng = np.random.default_rng(299) - - with MarginalModel() as m: - x = pm.Bernoulli("x", p=0.7) - y = pm.Normal("y", x) - z = pm.Normal("z", x) - det = pm.Deterministic("det", y + z) - pot = pm.Potential("pot", y + z + 1) - - with pytest.warns(UserWarning, match="There are multiple dependent variables"): - m.marginalize([x]) - - y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5, random_seed=rng) - np.testing.assert_almost_equal(y_draw + z_draw, det_draw) - np.testing.assert_almost_equal(det_draw, pot_draw - 1) - - y_value = m.rvs_to_values[y] - z_value = m.rvs_to_values[z] - det_value, pot_value = m.replace_rvs_by_values([det, pot]) - assert set(inputvars([det_value, pot_value])) == {y_value, z_value} - assert det_value.eval({y_value: 2, z_value: 5}) == 7 - assert pot_value.eval({y_value: 2, z_value: 5}) == 8 - - -@pytest.mark.filterwarnings("error") -def test_not_supported_marginalized_deterministic_and_potential(): - with MarginalModel() as m: - x = pm.Bernoulli("x", p=0.7) - y = pm.Normal("y", x) - det = pm.Deterministic("det", x + y) - - with pytest.raises( - NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det" - ): - m.marginalize([x]) - - with MarginalModel() as m: - x = pm.Bernoulli("x", p=0.7) - y = pm.Normal("y", x) - pot = pm.Potential("pot", x + y) - - with pytest.raises( - NotImplementedError, match="Cannot marginalize x due to dependent Potential pot" - ): - m.marginalize([x]) - - -@pytest.mark.filterwarnings("error") -@pytest.mark.parametrize( - "transform, expected_warning", - ( - (None, does_not_warn()), - (UNSET, does_not_warn()), - (transforms.log, does_not_warn()), - (transforms.Chain([transforms.log, transforms.logodds]), does_not_warn()), - ( - transforms.Interval(0, 1), - pytest.warns( - UserWarning, match="which depends on the marginalized idx may no longer work" - ), - ), - ( - transforms.Chain([transforms.log, transforms.Interval(0, 1)]), - pytest.warns( - UserWarning, match="which depends on the marginalized idx may no longer work" - ), - ), - ), -) -def test_marginalized_transforms(transform, expected_warning): - w = [0.1, 0.3, 0.6] - data = [0, 5, 10] - initval = 0.5 # Value that will be negative on the unconstrained space - - with pm.Model() as m_ref: - sigma = pm.Mixture( - "sigma", - w=w, - comp_dists=pm.HalfNormal.dist([1, 2, 3]), - initval=initval, - default_transform=transform, - ) - y = pm.Normal("y", 0, sigma, observed=data) - - with MarginalModel() as m: - idx = pm.Categorical("idx", p=w) - sigma = pm.HalfNormal( - "sigma", - pt.switch( - pt.eq(idx, 0), - 1, - pt.switch( - pt.eq(idx, 1), - 2, - 3, - ), - ), - initval=initval, - default_transform=transform, - ) - y = pm.Normal("y", 0, sigma, observed=data) - - with expected_warning: - m.marginalize([idx]) - - ip = m.initial_point() - if transform is not None: - if transform is UNSET: - transform_name = "log" - else: - transform_name = transform.name - assert f"sigma_{transform_name}__" in ip - np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip)) - - -def test_is_conditional_dependent_static_shape(): - """Test that we don't consider dependencies through "constant" shape 
Ops""" - x1 = pt.matrix("x1", shape=(None, 5)) - y1 = pt.random.normal(size=pt.shape(x1)) - assert is_conditional_dependent(y1, x1, [x1, y1]) - - x2 = pt.matrix("x2", shape=(9, 5)) - y2 = pt.random.normal(size=pt.shape(x2)) - assert not is_conditional_dependent(y2, x2, [x2, y2]) - - -def test_data_container(): - """Test that MarginalModel can handle Data containers.""" - with MarginalModel(coords={"obs": [0]}) as marginal_m: - x = pm.Data("x", 2.5) - idx = pm.Bernoulli("idx", p=0.7, dims="obs") - y = pm.Normal("y", idx * x, dims="obs") - - marginal_m.marginalize([idx]) - - logp_fn = marginal_m.compile_logp() - - with pm.Model(coords={"obs": [0]}) as m_ref: - x = pm.Data("x", 2.5) - y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[0, x], dims="obs") - - ref_logp_fn = m_ref.compile_logp() - - for i, x_val in enumerate((-1.5, 2.5, 3.5), start=1): - for m in (marginal_m, m_ref): - m.set_dim("obs", new_length=i, coord_values=tuple(range(i))) - pm.set_data({"x": x_val}, model=m) - - ip = marginal_m.initial_point() - np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip)) - - -@pytest.mark.parametrize("univariate", (True, False)) -def test_vector_univariate_mixture(univariate): - with MarginalModel() as m: - idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ()) - - def dist(idx, size): - return pm.math.switch( - pm.math.eq(idx, 0), - pm.Normal.dist([-10, -10], 1), - pm.Normal.dist([10, 10], 1), - ) - - pm.CustomDist("norm", idx, dist=dist) - - m.marginalize(idx) - logp_fn = m.compile_logp() - - if univariate: - with pm.Model() as ref_m: - pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,)) - else: - with pm.Model() as ref_m: - pm.Mixture( - "norm", - w=[0.5, 0.5], - comp_dists=[ - pm.MvNormal.dist([-10, -10], np.eye(2)), - pm.MvNormal.dist([10, 10], np.eye(2)), - ], - shape=(2,), - ) - ref_logp_fn = ref_m.compile_logp() - - for test_value in ( - [-10, -10], - [10, 10], - [-10, 10], - [-10, 10], - ): - pt = {"norm": test_value} - np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt)) - - -@pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}") -@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}") -def test_marginalized_hmm_normal_emission(batch_chain, batch_emission): - if batch_chain and not batch_emission: - pytest.skip("Redundant implicit combination") - - with MarginalModel() as m: - P = [[0, 1], [1, 0]] - init_dist = pm.Categorical.dist(p=[1, 0]) - chain = DiscreteMarkovChain( - "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None - ) - emission = pm.Normal( - "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None - ) - - m.marginalize([chain]) - logp_fn = m.compile_logp() - - test_value = np.array([-1, 1, -1, 1]) - expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval() - if batch_emission: - test_value = np.broadcast_to(test_value, (3, 4)) - expected_logp *= 3 - np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) - - -@pytest.mark.parametrize( - "categorical_emission", - [False, True], -) -def test_marginalized_hmm_categorical_emission(categorical_emission): - """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0""" - with MarginalModel() as m: - P = np.array([[0.5, 0.5], [0.3, 0.7]]) - init_dist = pm.Categorical.dist(p=[0.375, 0.625]) - chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2) - if categorical_emission: - emission = 
pm.Categorical( - "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6]) - ) - else: - emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6)) - m.marginalize([chain]) - - test_value = np.array([0, 0, 1]) - expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video - logp_fn = m.compile_logp() - np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) - - -@pytest.mark.parametrize("batch_emission1", (False, True)) -@pytest.mark.parametrize("batch_emission2", (False, True)) -def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2): - emission1_shape = (2, 4) if batch_emission1 else (4,) - emission2_shape = (2, 4) if batch_emission2 else (4,) - with MarginalModel() as m: - P = [[0, 1], [1, 0]] - init_dist = pm.Categorical.dist(p=[1, 0]) - chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3) - emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape) - emission_2 = pm.Normal( - "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape - ) - - with pytest.warns(UserWarning, match="multiple dependent variables"): - m.marginalize([chain]) - - logp_fn = m.compile_logp() - - test_value = np.array([-1, 1, -1, 1]) - multiplier = 2 + batch_emission1 + batch_emission2 - expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier - test_value_emission1 = np.broadcast_to(test_value, emission1_shape) - test_value_emission2 = np.broadcast_to(-test_value, emission2_shape) - test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2} - np.testing.assert_allclose(logp_fn(test_point), expected_logp) - - -def test_mutable_indexing_jax_backend(): - pytest.importorskip("jax") - from pymc.sampling.jax import get_jaxified_logp - - with MarginalModel() as model: - data = pm.Data("data", np.zeros(10)) - - cat_effect = pm.Normal("cat_effect", sigma=1, shape=5) - cat_effect_idx = pm.Data("cat_effect_idx", np.array([0, 1] * 5)) - - is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10) - pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data) - model.marginalize(["is_outlier"]) - get_jaxified_logp(model) - - -def test_marginal_model_func(): - def create_model(model_class): - with model_class(coords={"trial": range(10)}) as m: - idx = pm.Bernoulli("idx", p=0.5, dims="trial") - mu = pt.where(idx, 1, -1) - sigma = pm.HalfNormal("sigma") - y = pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10) - return m - - marginal_m = marginalize(create_model(pm.Model), ["idx"]) - assert isinstance(marginal_m, MarginalModel) - - reference_m = create_model(MarginalModel) - reference_m.marginalize(["idx"]) - - # Check forward graph representation is the same - marginal_fgraph, _ = fgraph_from_model(marginal_m) - reference_fgraph, _ = fgraph_from_model(reference_m) - assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs) - - # Check logp graph is the same - # This fails because OpFromGraphs comparison is broken - # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()]) - ip = marginal_m.initial_point() - np.testing.assert_allclose( - marginal_m.compile_logp()(ip), - reference_m.compile_logp()(ip), - )
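
As a companion to the diff above, here is a minimal, self-contained sketch (not part of the change itself) of the workflow the new test suite exercises: build an ordinary pm.Model, sum a discrete variable out of the logp with the functional marginalize helper from the relocated pymc_experimental.model.marginal.marginal_model module, and compare against an explicit mixture. The NormalMixture reference model is an illustrative equivalence written for exposition; it mirrors test_basic_marginalized_rv and test_marginal_model_func rather than reproducing either.

import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pymc_experimental.model.marginal.marginal_model import MarginalModel, marginalize


def build_model() -> pm.Model:
    # Ten observations whose mean flips between -1 and 1 via a per-trial latent Bernoulli.
    with pm.Model(coords={"trial": range(10)}) as m:
        idx = pm.Bernoulli("idx", p=0.5, dims="trial")
        mu = pt.where(idx, 1.0, -1.0)
        sigma = pm.HalfNormal("sigma")
        pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1.0] * 10)
    return m


# Functional API: returns a MarginalModel in which "idx" has been summed out of the logp.
marginal_m = marginalize(build_model(), ["idx"])
assert isinstance(marginal_m, MarginalModel)

# Reference: the same likelihood written directly as a two-component normal mixture.
with pm.Model(coords={"trial": range(10)}) as ref_m:
    sigma = pm.HalfNormal("sigma")
    pm.NormalMixture(
        "y", w=[0.5, 0.5], mu=[-1.0, 1.0], sigma=sigma, dims="trial", observed=[1.0] * 10
    )

ip = marginal_m.initial_point()
np.testing.assert_allclose(marginal_m.compile_logp()(ip), ref_m.compile_logp()(ip))

The same comparison could be written with the MarginalModel class directly, as reference_m does in test_marginal_model_func; the functional form is simply the thinner entry point the new tests favor when starting from a plain pm.Model.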