From b39f0d5aa0855aeb9962f51c922c7b2733a2903f Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 2 Nov 2023 14:55:49 +0100
Subject: [PATCH 1/4] Add test for univariate and multivariate marginal mixture

Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
---
 .../tests/model/test_marginal_model.py | 44 +++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py
index 2d0ad773..11df660e 100644
--- a/pymc_experimental/tests/model/test_marginal_model.py
+++ b/pymc_experimental/tests/model/test_marginal_model.py
@@ -629,3 +629,47 @@ def test_data_container():
 
     ip = marginal_m.initial_point()
     np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip))
+
+
+@pytest.mark.parametrize("univariate", (True, False))
+def test_vector_univariate_mixture(univariate):
+
+    with MarginalModel() as m:
+        idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
+
+        def dist(idx, size):
+            return pm.math.switch(
+                pm.math.eq(idx, 0),
+                pm.Normal.dist([-10, -10], 1),
+                pm.Normal.dist([10, 10], 1),
+            )
+
+        pm.CustomDist("norm", idx, dist=dist)
+
+    m.marginalize(idx)
+    logp_fn = m.compile_logp()
+
+    if univariate:
+        with pm.Model() as ref_m:
+            pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
+    else:
+        with pm.Model() as ref_m:
+            pm.Mixture(
+                "norm",
+                w=[0.5, 0.5],
+                comp_dists=[
+                    pm.MvNormal.dist([-10, -10], np.eye(2)),
+                    pm.MvNormal.dist([10, 10], np.eye(2)),
+                ],
+                shape=(2,),
+            )
+    ref_logp_fn = ref_m.compile_logp()
+
+    for test_value in (
+        [-10, -10],
+        [10, 10],
+        [-10, 10],
+        [10, -10],
+    ):
+        pt = {"norm": test_value}
+        np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))

From af42d7d4bc4d601551497d5866d81ba661ffbd28 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 16 Jan 2024 16:59:13 +0100
Subject: [PATCH 2/4] Minor cleanup of MarginalModel

---
 pymc_experimental/model/marginal_model.py            | 9 ++++++---
 pymc_experimental/tests/model/test_marginal_model.py | 2 +-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py
index 068d0c61..1e5194cc 100644
--- a/pymc_experimental/model/marginal_model.py
+++ b/pymc_experimental/model/marginal_model.py
@@ -582,10 +582,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
         raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
 
     ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
-    if max(ndim_supp) > 0:
+    if len(ndim_supp) != 1:
         raise NotImplementedError(
-            "Marginalization of withe dependent Multivariate RVs not implemented"
+            "Marginalization with dependent variables of different support dimensionality not implemented"
         )
+    [ndim_supp] = ndim_supp
+    if ndim_supp > 0:
+        raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented")
 
     marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
     dependent_rvs_input_rvs = [
@@ -623,7 +626,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     marginalization_op = FiniteDiscreteMarginalRV(
         inputs=list(replace_inputs.values()),
         outputs=cloned_outputs,
-        ndim_supp=0,
+        ndim_supp=ndim_supp,
     )
     marginalized_rvs = marginalization_op(*replace_inputs.keys())
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py
index 11df660e..b664feca 100644
--- a/pymc_experimental/tests/model/test_marginal_model.py
+++ b/pymc_experimental/tests/model/test_marginal_model.py
@@ -472,7 +472,7 @@ def test_not_supported_marginalized():
         y = pm.Dirichlet("y", a=pm.math.switch(x, [1, 1, 1], [10, 10, 10]))
         with pytest.raises(
             NotImplementedError,
-            match="Marginalization of withe dependent Multivariate RVs not implemented",
+            match="Marginalization with dependent Multivariate RVs not implemented",
         ):
             m.marginalize(x)
 

From c61f2cbfd77e144f01f8315d70a260c152f60e2a Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 16 Jan 2024 18:01:33 +0100
Subject: [PATCH 3/4] Refactor logic to reduce and add batched logp dimensions

---
 pymc_experimental/model/marginal_model.py | 85 +++++++++++++----------
 1 file changed, 47 insertions(+), 38 deletions(-)

diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py
index 1e5194cc..a250c47a 100644
--- a/pymc_experimental/model/marginal_model.py
+++ b/pymc_experimental/model/marginal_model.py
@@ -26,7 +26,7 @@
     vectorize_graph,
 )
 from pytensor.scan import map as scan_map
-from pytensor.tensor import TensorVariable
+from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.shape import Shape
 from pytensor.tensor.special import log_softmax
@@ -381,41 +381,36 @@ def transform_input(inputs):
         rv_dict = {}
         rv_dims = {}
 
-        for seed, rv in zip(seeds, vars_to_recover):
+        for seed, marginalized_rv in zip(seeds, vars_to_recover):
             supported_dists = (Bernoulli, Categorical, DiscreteUniform)
-            if not isinstance(rv.owner.op, supported_dists):
+            if not isinstance(marginalized_rv.owner.op, supported_dists):
                 raise NotImplementedError(
-                    f"RV with distribution {rv.owner.op} cannot be recovered. "
+                    f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
" f"Supported distribution include {supported_dists}" ) m = self.clone() - rv = m.vars_to_clone[rv] - m.unmarginalize([rv]) - dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs) - joint_logps = m.logp(vars=dependent_vars + [rv], sum=False) + marginalized_rv = m.vars_to_clone[marginalized_rv] + m.unmarginalize([marginalized_rv]) + dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs) + joint_logps = m.logp(vars=[marginalized_rv] + dependent_vars, sum=False) - marginalized_value = m.rvs_to_values[rv] + marginalized_value = m.rvs_to_values[marginalized_rv] other_values = [v for v in m.value_vars if v is not marginalized_value] # Handle batch dims for marginalized value and its dependent RVs - joint_logp = joint_logps[-1] - for dv in joint_logps[:-1]: - dbcast = dv.type.broadcastable - mbcast = marginalized_value.type.broadcastable - mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast - values_axis_bcast = [ - i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v - ] - joint_logp += dv.sum(values_axis_bcast) + marginalized_logp, *dependent_logps = joint_logps + joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps( + marginalized_rv.type, dependent_logps + ) - rv_shape = constant_fold(tuple(rv.shape)) - rv_domain = get_domain_of_finite_discrete_rv(rv) + rv_shape = constant_fold(tuple(marginalized_rv.shape)) + rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) rv_domain_tensor = pt.moveaxis( pt.full( (*rv_shape, len(rv_domain)), rv_domain, - dtype=rv.dtype, + dtype=marginalized_rv.dtype, ), -1, 0, @@ -431,7 +426,7 @@ def transform_input(inputs): joint_logps_norm = log_softmax(joint_logps, axis=-1) if return_samples: sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps) - if isinstance(rv.owner.op, DiscreteUniform): + if isinstance(marginalized_rv.owner.op, DiscreteUniform): sample_rv_outs += rv_domain[0] rv_loglike_fn = compile_pymc( @@ -456,18 +451,20 @@ def transform_input(inputs): logps, samples = zip(*logvs) logps = np.array(logps) samples = np.array(samples) - rv_dict[rv.name] = samples.reshape( + rv_dict[marginalized_rv.name] = samples.reshape( tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:], ) else: logps = np.array(logvs) - rv_dict["lp_" + rv.name] = logps.reshape( + rv_dict["lp_" + marginalized_rv.name] = logps.reshape( tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:], ) - if rv.name in m.named_vars_to_dims: - rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name]) - rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"] + if marginalized_rv.name in m.named_vars_to_dims: + rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name]) + rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [ + "lp_" + marginalized_rv.name + "_dim" + ] coords, dims = coords_and_dims_for_inferencedata(self) dims.update(rv_dims) @@ -647,6 +644,22 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]: raise NotImplementedError(f"Cannot compute domain for op {op}") +def _add_reduce_batch_dependent_logps( + marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable] +): + """Add the logps of dependent RVs while reducing extra batch dims as assessed from the `marginalized_type`.""" + + mbcast = marginalized_type.broadcastable + reduced_logps = [] + for dependent_logp in dependent_logps: + dbcast = dependent_logp.type.broadcastable + dim_diff = len(dbcast) - len(mbcast) + mbcast_aligned = 
+        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
+        reduced_logps.append(dependent_logp.sum(vbcast_axis))
+    return pt.add(*reduced_logps)
+
+
 @_logprob.register(FiniteDiscreteMarginalRV)
 def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     # Clone the inner RV graph of the Marginalized RV
@@ -662,17 +675,12 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
 
     # Reduce logp dimensions corresponding to broadcasted variables
-    joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
-    for inner_rv, inner_value in inner_rvs_to_values.items():
-        if inner_rv is marginalized_rv:
-            continue
-        vbcast = inner_value.type.broadcastable
-        mbcast = marginalized_rv.type.broadcastable
-        mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast
-        values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
-        joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)
-
-    # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
+    marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
+    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
+        marginalized_rv.type, logps_dict.values()
+    )
+
+    # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
     # values of the marginalized RV
     # Some inputs are not root inputs (such as transformed projections of value variables)
     # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
@@ -700,6 +708,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     )
 
     # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
+    # TODO: Try vectorize here
    if len(marginalized_rv_domain) <= 10:
         joint_logps = [
             joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)

From 3ebdfb5855c44730c3e79384757f604db8f57234 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 2 Nov 2023 16:52:29 +0100
Subject: [PATCH 4/4] Marginalize DiscreteMarkovChain

Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
---
 pymc_experimental/model/marginal_model.py     | 115 +++++++++++++---
 .../tests/model/test_marginal_model.py        |  85 +++++++++++++
 2 files changed, 185 insertions(+), 15 deletions(-)

diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py
index a250c47a..7ee047b3 100644
--- a/pymc_experimental/model/marginal_model.py
+++ b/pymc_experimental/model/marginal_model.py
@@ -10,21 +10,16 @@
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
-from pymc.logprob.basic import conditional_logp
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
 from pymc.pytensorf import compile_pymc, constant_fold, inputvars
 from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
-from pytensor import Mode
+from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import (
-    Constant,
-    FunctionGraph,
-    ancestors,
-    clone_replace,
-    vectorize_graph,
-)
+from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.graph.replace import vectorize_graph
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
@@ -33,6 +28,8 @@
 
 __all__ = ["MarginalModel"]
 
+from pymc_experimental.distributions import DiscreteMarkovChain
+
 
 class MarginalModel(Model):
     """Subclass of PyMC Model that implements functionality for automatic
@@ -247,16 +244,25 @@ def marginalize(
             self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
         ]
 
-        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
         for rv_to_marginalize in rvs_to_marginalize:
            if rv_to_marginalize not in self.free_RVs:
                 raise ValueError(
                     f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
                 )
+
+            rv_op = rv_to_marginalize.owner.op
+            if isinstance(rv_op, DiscreteMarkovChain):
+                if rv_op.n_lags > 1:
+                    raise NotImplementedError(
+                        "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
+                    )
+                if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
+                    raise NotImplementedError(
+                        "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
+                    )
-            if not isinstance(rv_to_marginalize.owner.op, supported_dists):
+            elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
                 raise NotImplementedError(
-                    f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
-                    f"Supported distribution include {supported_dists}"
+                    f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
                 )
 
             if rv_to_marginalize.name in self.named_vars_to_dims:
@@ -492,6 +498,10 @@ class FiniteDiscreteMarginalRV(MarginalRV):
     """Base class for Finite Discrete Marginalized RVs"""
 
 
+class DiscreteMarginalMarkovChainRV(MarginalRV):
+    """Base class for Discrete Marginal Markov Chain RVs"""
+
+
 def static_shape_ancestors(vars):
     """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
     return [
@@ -620,11 +630,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
     cloned_outputs = clone_replace(outputs, replace=replace_inputs)
 
-    marginalization_op = FiniteDiscreteMarginalRV(
+    if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
+        marginalize_constructor = DiscreteMarginalMarkovChainRV
+    else:
+        marginalize_constructor = FiniteDiscreteMarginalRV
+
+    marginalization_op = marginalize_constructor(
         inputs=list(replace_inputs.values()),
         outputs=cloned_outputs,
         ndim_supp=ndim_supp,
     )
+
     marginalized_rvs = marginalization_op(*replace_inputs.keys())
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
@@ -640,6 +656,9 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
     elif isinstance(op, DiscreteUniform):
         lower, upper = constant_fold(rv.owner.inputs[3:])
         return tuple(range(lower, upper + 1))
+    elif isinstance(op, DiscreteMarkovChain):
+        P = rv.owner.inputs[0]
+        return tuple(range(pt.get_vector_length(P[-1])))
 
     raise NotImplementedError(f"Cannot compute domain for op {op}")
 
@@ -647,7 +666,7 @@ def _add_reduce_batch_dependent_logps(
     marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
 ):
-    """Add the logps of dependent RVs while reducing extra batch dims as assessed from the `marginalized_type`."""
+    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
 
     mbcast = marginalized_type.broadcastable
     reduced_logps = []
     for dependent_logp in dependent_logps:
@@ -730,3 +749,69 @@ def logp_fn(marginalized_rv_const, *non_sequences):
 
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+
+
+@_logprob.register(DiscreteMarginalMarkovChainRV)
+def marginal_hmm_logp(op, values, *inputs, **kwargs):
+
+    marginalized_rvs_node = op.make_node(*inputs)
+    inner_rvs = clone_replace(
+        op.inner_outputs,
+        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+    )
+
+    chain_rv, *dependent_rvs = inner_rvs
+    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
+    domain = pt.arange(P.shape[-1], dtype="int32")
+
+    # Construct logp in two steps
+    # Step 1: Compute the probability of the data ("emissions") under every possible state (batch_logp_emissions)
+
+    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
+    # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
+    # PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step.
+    chain_value = chain_rv.clone()
+    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
+    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))
+
+    # Reduce and add the batch dims beyond the chain dimension
+    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
+        chain_rv.type, logp_emissions_dict.values()
+    )
+
+    # Add a batch dimension for the domain of the chain
+    chain_shape = constant_fold(tuple(chain_rv.shape))
+    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
+    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})
+
+    # Step 2: Compute the transition probabilities
+    # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
+    # We do it entirely in logs, though.
+
+    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
+    # the initial distribution. This is robust to everything the user can throw at it.
+    batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
+        batch_chain_value[..., 0]
+    )
+    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]
+
+    def step_alpha(logp_emission, log_alpha, log_P):
+        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
+        return logp_emission + step_log_prob
+
+    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
+    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
+    log_alpha_seq, _ = scan(
+        step_alpha,
+        non_sequences=[log_P],
+        outputs_info=[log_alpha_init],
+        # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
+        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
+    )
+    # Final logp is just the sum of the last scan state
+    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
+
+    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
+    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
+    dummy_logps = (pt.constant(0),) * (len(values) - 1)
+    return joint_logp, *dummy_logps
diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py
index b664feca..c0e1bd90 100644
--- a/pymc_experimental/tests/model/test_marginal_model.py
+++ b/pymc_experimental/tests/model/test_marginal_model.py
@@ -14,6 +14,7 @@
 from scipy.special import log_softmax, logsumexp
 from scipy.stats import halfnorm, norm
 
+from pymc_experimental.distributions import DiscreteMarkovChain
 from pymc_experimental.model.marginal_model import (
     FiniteDiscreteMarginalRV,
     MarginalModel,
@@ -673,3 +674,87 @@ def dist(idx, size):
     ):
         pt = {"norm": test_value}
         np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
+
+
+@pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}")
+@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}")
+def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
+    if batch_chain and not batch_emission:
+        pytest.skip("Redundant implicit combination")
+
+    with MarginalModel() as m:
+        P = [[0, 1], [1, 0]]
+        init_dist = pm.Categorical.dist(p=[1, 0])
+        chain = DiscreteMarkovChain(
+            "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None
+        )
+        emission = pm.Normal(
+            "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
+        )
+
+    m.marginalize([chain])
+    logp_fn = m.compile_logp()
+
+    test_value = np.array([-1, 1, -1, 1])
+    expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
+    if batch_emission:
+        test_value = np.broadcast_to(test_value, (3, 4))
+        expected_logp *= 3
+    np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
+
+
+@pytest.mark.parametrize(
+    "categorical_emission",
+    [
+        False,
+        # Categorical has a core vector parameter,
+        # so it is not possible to build a graph that uses elemwise operations exclusively
+        pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError)),
+    ],
+)
+def test_marginalized_hmm_categorical_emission(categorical_emission):
+    """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
+    with MarginalModel() as m:
+        P = np.array([[0.5, 0.5], [0.3, 0.7]])
+        init_dist = pm.Categorical.dist(p=[0.375, 0.625])
+        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
+        if categorical_emission:
+            emission = pm.Categorical(
+                "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6])
+            )
+        else:
+            emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
+    m.marginalize([chain])
+
+    test_value = np.array([0, 0, 1])
+    expected_logp = np.log(0.1344)  # Shown at the 10m22s mark in the video
+    logp_fn = m.compile_logp()
+    np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
+
+
+@pytest.mark.parametrize("batch_emission1", (False, True))
+@pytest.mark.parametrize("batch_emission2", (False, True))
+def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2):
+    emission1_shape = (2, 4) if batch_emission1 else (4,)
+    emission2_shape = (2, 4) if batch_emission2 else (4,)
+    with MarginalModel() as m:
+        P = [[0, 1], [1, 0]]
+        init_dist = pm.Categorical.dist(p=[1, 0])
+        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3)
+        emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape)
+        emission_2 = pm.Normal(
+            "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape
+        )
+
+    with pytest.warns(UserWarning, match="multiple dependent variables"):
+        m.marginalize([chain])
+
+    logp_fn = m.compile_logp()
+
+    test_value = np.array([-1, 1, -1, 1])
+    multiplier = 2 + batch_emission1 + batch_emission2
+    expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier
+    test_value_emission1 = np.broadcast_to(test_value, emission1_shape)
+    test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
+    test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
+    np.testing.assert_allclose(logp_fn(test_point), expected_logp)
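
For reviewers who want to exercise the series end to end, the lines below are a minimal usage sketch of the feature added in PATCH 4, adapted directly from test_marginalized_hmm_normal_emission above. It is illustrative only and not part of the patches; it assumes pymc and a pymc-experimental checkout with these four patches applied are importable, and every name used comes from the tests in this series.

# Sketch (not part of the patches): marginalize a DiscreteMarkovChain with MarginalModel.
import numpy as np
import pymc as pm

from pymc_experimental.distributions import DiscreteMarkovChain
from pymc_experimental.model.marginal_model import MarginalModel

with MarginalModel() as m:
    # Two-state chain that deterministically alternates between states 0 and 1
    P = [[0, 1], [1, 0]]
    init_dist = pm.Categorical.dist(p=[1, 0])
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3)
    # Gaussian emissions centered at -1 (state 0) or +1 (state 1)
    emission = pm.Normal("emission", mu=chain * 2 - 1, sigma=1e-1)

# Remove the discrete chain from the model; its logp is summed out via the
# forward algorithm implemented in marginal_hmm_logp (PATCH 4)
m.marginalize([chain])

logp_fn = m.compile_logp()
print(logp_fn({"emission": np.array([-1.0, 1.0, -1.0, 1.0])}))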