diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 26b784b36f..ae44bb65e2 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -24,7 +24,6 @@ import pymc from pymc.aesaraf import extract_obs_data -from pymc.distributions import logpt from pymc.model import modelcontext from pymc.util import get_default_varnames @@ -264,11 +263,15 @@ def _extract_log_likelihood(self, trace): if self.model is None: return None + # TODO: We no longer need one function per observed variable if self.log_likelihood is True: - cached = [(var, self.model.fn(logpt(var))) for var in self.model.observed_RVs] + cached = [ + (var, self.model.fn(self.model.logp_elemwiset(var)[0])) + for var in self.model.observed_RVs + ] else: cached = [ - (var, self.model.fn(logpt(var))) + (var, self.model.fn(self.model.logp_elemwiset(var)[0])) for var in self.model.observed_RVs if var.name in self.log_likelihood ] diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index a0a9eb7e93..956e555d1f 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -15,7 +15,7 @@ from collections.abc import Mapping from functools import singledispatch -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union import aesara.tensor as at import numpy as np @@ -24,10 +24,8 @@ from aeppl.logprob import logcdf as logcdf_aeppl from aeppl.logprob import logprob as logp_aeppl from aeppl.transforms import TransformValuesOpt -from aesara import config from aesara.graph.basic import graph_inputs, io_toposort -from aesara.graph.op import Op, compute_test_value -from aesara.tensor.random.op import RandomVariable +from aesara.graph.op import Op from aesara.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -121,7 +119,7 @@ def _get_scaling(total_size, shape, ndim): def logpt( - var: TensorVariable, + var: Union[TensorVariable, List[TensorVariable]], rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None, *, jacobian: bool = True, @@ -129,7 +127,7 @@ def logpt( transformed: bool = True, sum: bool = True, **kwargs, -) -> TensorVariable: +) -> Union[TensorVariable, List[TensorVariable]]: """Create a measure-space (i.e. log-likelihood) graph for a random variable or a list of random variables at a given point. @@ -156,7 +154,7 @@ def logpt( transformed Apply transforms. sum - Sum the log-likelihood. + Sum the log-likelihood or return each term as a separate list item. """ # TODO: In future when we drop support for tag.value_var most of the following @@ -164,59 +162,64 @@ def logpt( # joint_logprob directly. # If var is not a list make it one. - if not isinstance(var, list): + if not isinstance(var, (list, tuple)): var = [var] - # If logpt isn't provided values and the variable (provided in var) - # is an RV, it is assumed that the tagged value var or observation is - # the value variable for that particular RV. + # If logpt isn't provided values it is assumed that the tagged value var or + # observation is the value variable for that particular RV. if rv_values is None: rv_values = {} - for _var in var: - if isinstance(_var.owner.op, RandomVariable): - rv_value_var = getattr( - _var.tag, "observations", getattr(_var.tag, "value_var", _var) - ) - rv_values = {_var: rv_value_var} + for rv in var: + value_var = getattr(rv.tag, "observations", getattr(rv.tag, "value_var", None)) + if value_var is None: + raise ValueError(f"No value variable found for var {rv}") + rv_values[rv] = value_var + # Else we assume we were given a single rv and respective value elif not isinstance(rv_values, Mapping): - # Else if we're given a single value and a single variable we assume a mapping among them. - rv_values = ( - {var[0]: at.as_tensor_variable(rv_values).astype(var[0].type)} if len(var) == 1 else {} - ) - - # Since the filtering of logp graph is based on value variables - # provided to this function - if not rv_values: - warnings.warn("No value variables provided the logp will be an empty graph") + if len(var) == 1: + rv_values = {var[0]: at.as_tensor_variable(rv_values).astype(var[0].type)} + else: + raise ValueError("rv_values must be a dict if more than one var is requested") if scaling: rv_scalings = {} - for _var in var: - rv_value_var = getattr(_var.tag, "observations", getattr(_var.tag, "value_var", _var)) - rv_scalings[rv_value_var] = _get_scaling( - getattr(_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim + for rv, value_var in rv_values.items(): + rv_scalings[value_var] = _get_scaling( + getattr(rv.tag, "total_size", None), value_var.shape, value_var.ndim ) # Aeppl needs all rv-values pairs, not just that of the requested var. # Hence we iterate through the graph to collect them. tmp_rvs_to_values = rv_values.copy() - transform_map = {} for node in io_toposort(graph_inputs(var), var): try: curr_vars = [node.default_output()] except ValueError: curr_vars = node.outputs for curr_var in curr_vars: - rv_value_var = getattr( + if curr_var in tmp_rvs_to_values: + continue + # Check if variable has a value variable + value_var = getattr( curr_var.tag, "observations", getattr(curr_var.tag, "value_var", None) ) - if rv_value_var is None: - continue - rv_value = rv_values.get(curr_var, rv_value_var) - tmp_rvs_to_values[curr_var] = rv_value - # Along with value variables we also check for transforms if any. - if hasattr(rv_value_var.tag, "transform") and transformed: - transform_map[rv_value] = rv_value_var.tag.transform + if value_var is not None: + tmp_rvs_to_values[curr_var] = value_var + + # After collecting all necessary rvs and values, we check for any value transforms + transform_map = {} + if transformed: + for rv, value_var in tmp_rvs_to_values.items(): + if hasattr(value_var.tag, "transform"): + transform_map[value_var] = value_var.tag.transform + # If the provided value_variable does not have transform information, we + # check if the original `rv.tag.value_var` does. + # TODO: This logic should be replaced by an explicit dict of + # `{value_var: transform}` similar to `rv_values`. + else: + original_value_var = getattr(rv.tag, "value_var", None) + if original_value_var is not None and hasattr(original_value_var.tag, "transform"): + transform_map[value_var] = original_value_var.tag.transform transform_opt = TransformValuesOpt(transform_map) temp_logp_var_dict = factorized_joint_logprob( @@ -224,40 +227,27 @@ def logpt( ) # aeppl returns the logpt for every single value term we provided to it. This includes - # the extra values we plugged in above so we need to filter those out. + # the extra values we plugged in above, so we filter those we actually wanted in the + # same order they were given in. logp_var_dict = {} - for value_var, _logp in temp_logp_var_dict.items(): - if value_var in rv_values.values(): - logp_var_dict[value_var] = _logp + for value_var in rv_values.values(): + logp_var_dict[value_var] = temp_logp_var_dict[value_var] - # If it's an empty dictionary the logp is None - if not logp_var_dict: - logp_var = None - else: - # Otherwise apply appropriate scalings and at.add and/or at.sum the - # graphs accordingly. - if scaling: - for _value in logp_var_dict.keys(): - if _value in rv_scalings: - logp_var_dict[_value] *= rv_scalings[_value] - - if len(logp_var_dict) == 1: - logp_var_dict = tuple(logp_var_dict.values())[0] - if sum: - logp_var = at.sum(logp_var_dict) - else: - logp_var = logp_var_dict - else: - if sum: - logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()]) - else: - logp_var = at.add(*logp_var_dict.values()) + if scaling: + for value_var in logp_var_dict.keys(): + if value_var in rv_scalings: + logp_var_dict[value_var] *= rv_scalings[value_var] - # Recompute test values for the changes introduced by the replacements - # above. - if config.compute_test_value != "off": - for node in io_toposort(graph_inputs((logp_var,)), (logp_var,)): - compute_test_value(node) + if sum: + logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()]) + else: + logp_var = list(logp_var_dict.values()) + # TODO: deprecate special behavior when only one variable is requested and + # always return a list. This is here for backwards compatibility as logpt + # started as a replacement to factor.logpt, but it should now be considered an + # internal function reached only via model.logp* methods. + if len(logp_var) == 1: + logp_var = logp_var[0] return logp_var @@ -276,23 +266,15 @@ def logcdf(rv, value): return logcdf_aeppl(rv, value) -@singledispatch -def _logcdf(op, values, *args, **kwargs): - """Create a log-CDF graph. - - This function dispatches on the type of `op`, which should be a subclass - of `RandomVariable`. If you want to implement new log-CDF graphs - for a `RandomVariable`, register a new function on this dispatcher. - - """ - raise NotImplementedError() - - def logpt_sum(*args, **kwargs): """Return the sum of the logp values for the given observations. Subclasses can use this to improve the speed of logp evaluations if only the sum of the logp values is needed. """ - # TODO: Deprecate this + warnings.warn( + "logpt_sum has been deprecated, you can use logpt instead, which now defaults" + "to the same behavior of logpt_sum", + DeprecationWarning, + ) return logpt(*args, sum=True, **kwargs) diff --git a/pymc/model.py b/pymc/model.py index 4dcb20f16c..dd526561d0 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -284,9 +284,8 @@ def logp(self): """Compiled log probability density function""" return self.model.fn(self.logpt) - @property - def logp_elemwise(self): - return self.model.fn(self.logp_elemwiset) + def logp_elemwise(self, vars=None, jacobian=True): + return self.model.fn(self.logp_elemwiset(vars=vars, jacobian=jacobian)) def dlogp(self, vars=None): """Compiled log probability density gradient function""" @@ -728,6 +727,66 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs): } return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs) + def logp_elemwiset( + self, + vars: Optional[Union[Variable, List[Variable]]] = None, + jacobian: bool = True, + ) -> List[Variable]: + """Elemwise log-probability of the model. + + Parameters + ---------- + vars: list of random variables or potential terms, optional + Compute the gradient with respect to those variables. If None, use all + free and observed random variables, as well as potential terms in model. + jacobian + Whether to include jacobian terms in logprob graph. Defaults to True. + + Returns + ------- + Elemwise logp terms for ecah requested variable, in the same order of input. + """ + if vars is None: + vars = self.free_RVs + self.observed_RVs + self.potentials + elif not isinstance(vars, (list, tuple)): + vars = [vars] + + # We need to separate random variables from potential terms, and remember their + # original order so that we can merge them together in the same order at the end + rv_values = {} + potentials = [] + rv_order, potential_order = [], [] + for i, var in enumerate(vars): + value_var = self.rvs_to_values.get(var) + if value_var is not None: + rv_values[var] = value_var + rv_order.append(i) + else: + if var in self.potentials: + potentials.append(var) + potential_order.append(i) + else: + raise ValueError( + f"Requested variable {var} not found among the model variables" + ) + + rv_logps = [] + if rv_values: + rv_logps = logpt(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian) + if not isinstance(rv_logps, list): + rv_logps = [rv_logps] + + # Replace random variables by their value variables in potential terms + potential_logps = [] + if potentials: + potential_logps, _ = rvs_to_value_vars(potentials, apply_transforms=True) + + logp_elemwise = [None] * len(vars) + for logp_order, logp in zip((rv_order + potential_order), (rv_logps + potential_logps)): + logp_elemwise[logp_order] = logp + + return logp_elemwise + @property def logpt(self): """Aesara scalar of log-probability of the model""" @@ -1239,6 +1298,11 @@ def make_obs_var( ) warnings.warn(impute_message, ImputationWarning) + if rv_var.owner.op.ndim_supp > 0: + raise NotImplementedError( + f"Automatic inputation is only supported for univariate RandomVariables, but {rv_var} is multivariate" + ) + # We can get a random variable comprised of only the unobserved # entries by lifting the indices through the `RandomVariable` `Op`. @@ -1271,11 +1335,21 @@ def make_obs_var( clone=False, ) (observed_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner) + # Make a clone of the RV, but change the rng so that observed and missing + # are not treated as equivalent nodes by aesara. This would happen if the + # size of the masked and unmasked array happened to coincide + _, size, _, *inps = observed_rv_var.owner.inputs + rng = self.model.next_rng() + observed_rv_var = observed_rv_var.owner.op(*inps, size=size, rng=rng) + # Add default_update to new rng + new_rng = observed_rv_var.owner.outputs[0] + observed_rv_var.update = (rng, new_rng) + rng.default_update = new_rng observed_rv_var.name = f"{name}_observed" observed_rv_var.tag.observations = nonmissing_data - self.create_value_var(observed_rv_var, transform) + self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data) self.add_random_variable(observed_rv_var, dims) self.observed_RVs.append(observed_rv_var) @@ -1285,22 +1359,21 @@ def make_obs_var( rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var) rv_var = Deterministic(name, rv_var, self, dims, auto=True) - elif sps.issparse(data): - data = sparse.basic.as_sparse(data, name=name) - rv_var.tag.observations = data - self.create_value_var(rv_var, transform) - self.add_random_variable(rv_var, dims) - self.observed_RVs.append(rv_var) else: - data = at.as_tensor_variable(data, name=name) + if sps.issparse(data): + data = sparse.basic.as_sparse(data, name=name) + else: + data = at.as_tensor_variable(data, name=name) rv_var.tag.observations = data - self.create_value_var(rv_var, transform) + self.create_value_var(rv_var, transform=None, value_var=data) self.add_random_variable(rv_var, dims) self.observed_RVs.append(rv_var) return rv_var - def create_value_var(self, rv_var: TensorVariable, transform: Any) -> TensorVariable: + def create_value_var( + self, rv_var: TensorVariable, transform: Any, value_var: Optional[Variable] = None + ) -> TensorVariable: """Create a ``TensorVariable`` that will be used as the random variable's "value" in log-likelihood graphs. @@ -1311,13 +1384,13 @@ def create_value_var(self, rv_var: TensorVariable, transform: Any) -> TensorVari this branch of the conditional. """ - value_var = rv_var.type() + if value_var is None: + value_var = rv_var.type() + value_var.name = rv_var.name if aesara.config.compute_test_value != "off": value_var.tag.test_value = rv_var.tag.test_value - value_var.name = rv_var.name - rv_var.tag.value_var = value_var # Make the value variable a transformed value variable, diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 051e9818fd..a8785f3abd 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -122,7 +122,7 @@ def _get_log_likelihood(model, samples): "Compute log-likelihood for all observations" data = {} for v in model.observed_RVs: - logp_v = replace_shared_variables([logpt(v)]) + logp_v = replace_shared_variables([model.logp_elemwiset(v)[0]]) fgraph = FunctionGraph(model.value_vars, logp_v, clone=False) optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"]) jax_fn = jax_funcify(fgraph) diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index a4c1ee472d..8f4ad3e2da 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -2521,9 +2521,11 @@ def test_continuous(self): assert logpt(InfBoundedNormal, 0).eval() != -np.inf assert logpt(InfBoundedNormal, 11).eval() != -np.inf - value = at.dscalar("x") + value = model.rvs_to_values[LowerNormalTransform] assert logpt(LowerNormalTransform, value).eval({value: -1}) != -np.inf + value = model.rvs_to_values[UpperNormalTransform] assert logpt(UpperNormalTransform, value).eval({value: 1}) != -np.inf + value = model.rvs_to_values[BoundedNormalTransform] assert logpt(BoundedNormalTransform, value).eval({value: 0}) != -np.inf assert logpt(BoundedNormalTransform, value).eval({value: 11}) != -np.inf diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 0216de75fe..ab97fa26c4 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -45,7 +45,7 @@ def random_polyagamma(*args, **kwargs): from pymc.distributions.continuous import get_tau_sigma, interpolated from pymc.distributions.discrete import _OrderedLogistic, _OrderedProbit from pymc.distributions.dist_math import clipped_beta_rvs -from pymc.distributions.logprob import logpt +from pymc.distributions.logprob import logp from pymc.distributions.multivariate import _OrderedMultinomial, quaddist_matrix from pymc.distributions.shape_utils import to_tuple from pymc.tests.helpers import SeededTest, select_by_precision @@ -1626,8 +1626,8 @@ def test_errors(self): rowcov=np.eye(3), colcov=np.eye(3), ) - with pytest.raises(TypeError): - logpt(matrixnormal, aesara.tensor.ones((3, 3, 3))) + with pytest.raises(ValueError): + logp(matrixnormal, aesara.tensor.ones((3, 3, 3))) with pm.Model(): with pytest.warns(FutureWarning): @@ -1856,7 +1856,7 @@ def test_density_dist_without_random(self): pm.DensityDist( "density_dist", mu, - logp=lambda value, mu: logpt(pm.Normal.dist(mu, 1, size=100), value), + logp=lambda value, mu: logp(pm.Normal.dist(mu, 1, size=100), value), observed=np.random.randn(100), initval=0, ) diff --git a/pymc/tests/test_idata_conversion.py b/pymc/tests/test_idata_conversion.py index bfd275c311..f950f994b9 100644 --- a/pymc/tests/test_idata_conversion.py +++ b/pymc/tests/test_idata_conversion.py @@ -143,6 +143,11 @@ def test_to_idata(self, data, eight_schools_params, chains, draws): np.isclose(ivalues[chain], values[chain * draws : (chain + 1) * draws]) ) + chains = inference_data.posterior.dims["chain"] + draws = inference_data.posterior.dims["draw"] + obs = inference_data.observed_data["obs"] + assert inference_data.log_likelihood["obs"].shape == (chains, draws) + obs.shape + def test_predictions_to_idata(self, data, eight_schools_params): "Test that we can add predictions to a previously-existing InferenceData." test_dict = { @@ -329,6 +334,11 @@ def test_missing_data_model(self): fails = check_multiple_attrs(test_dict, inference_data) assert not fails + # The missing part of partial observed RVs is not included in log_likelihood + # See https://github.com/pymc-devs/pymc/issues/5255 + assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3) + + @pytest.mark.xfal(reason="Multivariate partial observed RVs not implemented for V4") @pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4") def test_mv_missing_data_model(self): data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1) @@ -375,8 +385,12 @@ def test_multiple_observed_rv(self, log_likelihood): if not log_likelihood: test_dict.pop("log_likelihood") test_dict["~log_likelihood"] = [] - if isinstance(log_likelihood, list): + elif isinstance(log_likelihood, list): test_dict["log_likelihood"] = ["y1", "~y2"] + assert inference_data.log_likelihood["y1"].shape == (2, 100, 10) + else: + assert inference_data.log_likelihood["y1"].shape == (2, 100, 10) + assert inference_data.log_likelihood["y2"].shape == (2, 100, 100) fails = check_multiple_attrs(test_dict, inference_data) assert not fails @@ -445,12 +459,12 @@ def test_single_observation(self): inference_data = pm.sample(500, chains=2, return_inferencedata=True) assert inference_data + assert inference_data.log_likelihood["w"].shape == (2, 500, 1) - @pytest.mark.xfail(reason="Potential not refactored for v4") def test_potential(self): with pm.Model(): x = pm.Normal("x", 0.0, 1.0) - pm.Potential("z", logpt(pm.Normal.dist(x, 1.0), np.random.randn(10))) + pm.Potential("z", pm.logp(pm.Normal.dist(x, 1.0), np.random.randn(10))) inference_data = pm.sample(100, chains=2, return_inferencedata=True) assert inference_data @@ -463,7 +477,7 @@ def test_constant_data(self, use_context): y = pm.Data("y", [1.0, 2.0, 3.0]) beta = pm.Normal("beta", 0, 1) obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable - trace = pm.sample(100, tune=100, return_inferencedata=False) + trace = pm.sample(100, chains=2, tune=100, return_inferencedata=False) if use_context: inference_data = to_inference_data(trace=trace) @@ -472,6 +486,7 @@ def test_constant_data(self, use_context): test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]} fails = check_multiple_attrs(test_dict, inference_data) assert not fails + assert inference_data.log_likelihood["obs"].shape == (2, 100, 3) def test_predictions_constant_data(self): with pm.Model(): @@ -570,7 +585,7 @@ def test_multivariate_observations(self): with pm.Model(coords=coords): p = pm.Beta("p", 1, 1, size=(3,)) pm.Multinomial("y", 20, p, dims=("experiment", "direction"), observed=data) - idata = pm.sample(draws=50, tune=100, return_inferencedata=True) + idata = pm.sample(draws=50, chains=2, tune=100, return_inferencedata=True) test_dict = { "posterior": ["p"], "sample_stats": ["lp"], @@ -581,6 +596,7 @@ def test_multivariate_observations(self): assert not fails assert "direction" not in idata.log_likelihood.dims assert "direction" in idata.observed_data.dims + assert idata.log_likelihood["y"].shape == (2, 50, 20) def test_constant_data_coords_issue_5046(self): """This is a regression test against a bug where a local coords variable was overwritten.""" diff --git a/pymc/tests/test_logprob.py b/pymc/tests/test_logprob.py index c12fbc92cd..53a1061cb0 100644 --- a/pymc/tests/test_logprob.py +++ b/pymc/tests/test_logprob.py @@ -144,7 +144,8 @@ def test_logpt_subtensor(): I_value_var = I_rv.type() I_value_var.name = "I_value" - A_idx_logp = logpt(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False) + A_idx_logps = logpt(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False) + A_idx_logp = at.add(*A_idx_logps) logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp) diff --git a/pymc/tests/test_missing.py b/pymc/tests/test_missing.py index 2769b88623..115175df24 100644 --- a/pymc/tests/test_missing.py +++ b/pymc/tests/test_missing.py @@ -16,11 +16,13 @@ import numpy as np import pandas as pd import pytest +import scipy.stats +from aesara.graph import graph_inputs from numpy import array, ma -from pymc.distributions.continuous import Gamma, Normal, Uniform -from pymc.distributions.transforms import interval +from pymc import logpt +from pymc.distributions import Dirichlet, Gamma, Normal, Uniform from pymc.exceptions import ImputationWarning from pymc.model import Model from pymc.sampling import sample, sample_posterior_predictive, sample_prior_predictive @@ -94,10 +96,10 @@ def test_interval_missing_observations(): with pytest.warns(ImputationWarning): theta2 = Normal("theta2", mu=theta1, observed=obs2, rng=rng) - assert "theta1_observed_interval__" in model.named_vars + assert "theta1_observed" in model.named_vars assert "theta1_missing_interval__" in model.named_vars - assert isinstance( - model.rvs_to_values[model.named_vars["theta1_observed"]].tag.transform, interval + assert not hasattr( + model.rvs_to_values[model.named_vars["theta1_observed"]].tag, "transform" ) prior_trace = sample_prior_predictive(return_inferencedata=False) @@ -164,3 +166,69 @@ def test_missing_logp(): m_missing_logp = m_missing.logp({"theta1_missing": [2, 4], "theta2_missing": [0, 1, 3]}) assert m_logp == m_missing_logp + + +def test_missing_multivariate(): + """Test model with missing variables whose transform changes base shape still works""" + + with Model() as m_miss: + with pytest.raises( + NotImplementedError, + match="Automatic inputation is only supported for univariate RandomVariables", + ): + x = Dirichlet( + "x", a=[1, 2, 3], observed=np.array([[0.3, 0.3, 0.4], [np.nan, np.nan, np.nan]]) + ) + + # TODO: Test can be used when local_subtensor_rv_lift supports multivariate distributions + # from pymc.distributions.transforms import simplex + # + # with Model() as m_unobs: + # x = Dirichlet("x", a=[1, 2, 3]) + # + # inp_vals = simplex.forward(np.array([0.3, 0.3, 0.4])).eval() + # assert np.isclose( + # m_miss.logp({"x_missing_simplex__": inp_vals}), + # m_unobs.logp_nojac({"x_simplex__": inp_vals}) * 2, + # ) + + +def test_missing_vector_parameter(): + with Model() as m: + x = Normal( + "x", + np.array([-10, 10]), + 0.1, + observed=np.array([[np.nan, 10], [-10, np.nan], [np.nan, np.nan]]), + ) + x_draws = x.eval() + assert x_draws.shape == (3, 2) + assert np.all(x_draws[:, 0] < 0) + assert np.all(x_draws[:, 1] > 0) + assert np.isclose( + m.logp({"x_missing": np.array([-10, 10, -10, 10])}), + scipy.stats.norm(scale=0.1).logpdf(0) * 6, + ) + + +def test_missing_symmetric(): + """Check that logpt works when partially observed variable have equal observed and + unobserved dimensions. + + This would fail in a previous implementation because the two variables would be + equivalent and one of them would be discarded during MergeOptimization while + buling the logpt graph + """ + with Model() as m: + x = Gamma("x", alpha=3, beta=10, observed=np.array([1, np.nan])) + + x_obs_rv = m["x_observed"] + x_obs_vv = m.rvs_to_values[x_obs_rv] + + x_unobs_rv = m["x_missing"] + x_unobs_vv = m.rvs_to_values[x_unobs_rv] + + logp = logpt([x_obs_rv, x_unobs_rv], {x_obs_rv: x_obs_vv, x_unobs_rv: x_unobs_vv}) + logp_inputs = list(graph_inputs([logp])) + assert x_obs_vv in logp_inputs + assert x_unobs_vv in logp_inputs diff --git a/pymc/tests/test_smc.py b/pymc/tests/test_smc.py index e862018ba6..2d4a6e060c 100644 --- a/pymc/tests/test_smc.py +++ b/pymc/tests/test_smc.py @@ -409,12 +409,12 @@ def test_multiple_simulators(self): a_val = m.rvs_to_values[a] sim1_val = m.rvs_to_values[sim1] logp_sim1 = pm.logpt(sim1, sim1_val) - logp_sim1_fn = aesara.function([sim1_val, a_val], logp_sim1) + logp_sim1_fn = aesara.function([a_val], logp_sim1) b_val = m.rvs_to_values[b] sim2_val = m.rvs_to_values[sim2] logp_sim2 = pm.logpt(sim2, sim2_val) - logp_sim2_fn = aesara.function([sim2_val, b_val], logp_sim2) + logp_sim2_fn = aesara.function([b_val], logp_sim2) assert any( node for node in logp_sim1_fn.maker.fgraph.toposort() if isinstance(node.op, SortOp) diff --git a/pymc/tests/test_transforms.py b/pymc/tests/test_transforms.py index 8fafd50e9e..d38300cdc9 100644 --- a/pymc/tests/test_transforms.py +++ b/pymc/tests/test_transforms.py @@ -23,7 +23,7 @@ import pymc as pm import pymc.distributions.transforms as tr -from pymc.aesaraf import jacobian +from pymc.aesaraf import floatX, jacobian from pymc.distributions import logpt from pymc.tests.checks import close_to, close_to_logical from pymc.tests.helpers import SeededTest @@ -285,40 +285,46 @@ def build_model(self, distfam, params, size, transform, initval=None): def check_transform_elementwise_logp(self, model): x = model.free_RVs[0] - x0 = x.tag.value_var - assert x.ndim == logpt(x, sum=False).ndim + x_val_transf = x.tag.value_var - pt = model.initial_point - array = np.random.randn(*pt[x0.name].shape) - transform = x0.tag.transform - logp_notrans = logpt(x, transform.backward(array, *x.owner.inputs), transformed=False) + pt = model.recompute_initial_point(0) + test_array_transf = floatX(np.random.randn(*pt[x_val_transf.name].shape)) + transform = x_val_transf.tag.transform + test_array_untransf = transform.backward(test_array_transf, *x.owner.inputs).eval() - jacob_det = transform.log_jac_det(aesara.shared(array), *x.owner.inputs) - assert logpt(x, sum=False).ndim == jacob_det.ndim + # Create input variable with same dimensionality as untransformed test_array + x_val_untransf = at.constant(test_array_untransf).type() - v1 = logpt(x, array, jacobian=False).eval() - v2 = logp_notrans.eval() + jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) + assert logpt(x, sum=False).ndim == x.ndim == jacob_det.ndim + + v1 = logpt(x, x_val_transf, jacobian=False).eval({x_val_transf: test_array_transf}) + v2 = logpt(x, x_val_untransf, transformed=False).eval({x_val_untransf: test_array_untransf}) close_to(v1, v2, tol) - def check_vectortransform_elementwise_logp(self, model, vect_opt=0): + def check_vectortransform_elementwise_logp(self, model): x = model.free_RVs[0] - x0 = x.tag.value_var - # TODO: For some reason the ndim relations - # dont hold up here. But final log-probablity - # values are what we expected. - # assert (x.ndim - 1) == logpt(x, sum=False).ndim - - pt = model.initial_point - array = np.random.randn(*pt[x0.name].shape) - transform = x0.tag.transform - logp_nojac = logpt(x, transform.backward(array, *x.owner.inputs), transformed=False) - - jacob_det = transform.log_jac_det(aesara.shared(array), *x.owner.inputs) - # assert logpt(x).ndim == jacob_det.ndim - + x_val_transf = x.tag.value_var + + pt = model.recompute_initial_point(0) + test_array_transf = floatX(np.random.randn(*pt[x_val_transf.name].shape)) + transform = x_val_transf.tag.transform + test_array_untransf = transform.backward(test_array_transf, *x.owner.inputs).eval() + + # Create input variable with same dimensionality as untransformed test_array + x_val_untransf = at.constant(test_array_untransf).type() + + jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) + # Original distribution is univariate + if x.owner.op.ndim_supp == 0: + assert logpt(x, sum=False).ndim == x.ndim == (jacob_det.ndim + 1) + # Original distribution is multivariate + else: + assert logpt(x, sum=False).ndim == (x.ndim - 1) == jacob_det.ndim + + a = logpt(x, x_val_transf, jacobian=False).eval({x_val_transf: test_array_transf}) + b = logpt(x, x_val_untransf, transformed=False).eval({x_val_untransf: test_array_untransf}) # Hack to get relative tolerance - a = logpt(x, array.astype(aesara.config.floatX), jacobian=False).eval() - b = logp_nojac.eval() close_to(a, b, np.abs(0.5 * (a + b) * tol)) @pytest.mark.parametrize( @@ -406,7 +412,7 @@ def test_vonmises(self, mu, kappa, size): ) def test_dirichlet(self, a, size): model = self.build_model(pm.Dirichlet, {"a": a}, size=size, transform=tr.simplex) - self.check_vectortransform_elementwise_logp(model, vect_opt=1) + self.check_vectortransform_elementwise_logp(model) def test_normal_ordered(self): model = self.build_model( @@ -416,7 +422,7 @@ def test_normal_ordered(self): initval=np.asarray([-1.0, 1.0, 4.0]), transform=tr.ordered, ) - self.check_vectortransform_elementwise_logp(model, vect_opt=0) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize( "sd,size", @@ -434,7 +440,7 @@ def test_half_normal_ordered(self, sd, size): initval=initval, transform=tr.Chain([tr.log, tr.ordered]), ) - self.check_vectortransform_elementwise_logp(model, vect_opt=0) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize("lam,size", [(2.5, (2,)), (np.ones(3), (4, 3))]) def test_exponential_ordered(self, lam, size): @@ -446,7 +452,7 @@ def test_exponential_ordered(self, lam, size): initval=initval, transform=tr.Chain([tr.log, tr.ordered]), ) - self.check_vectortransform_elementwise_logp(model, vect_opt=0) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize( "a,b,size", @@ -468,7 +474,7 @@ def test_beta_ordered(self, a, b, size): initval=initval, transform=tr.Chain([tr.logodds, tr.ordered]), ) - self.check_vectortransform_elementwise_logp(model, vect_opt=0) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize( "lower,upper,size", @@ -491,7 +497,7 @@ def transform_params(*inputs): initval=initval, transform=tr.Chain([interval, tr.ordered]), ) - self.check_vectortransform_elementwise_logp(model, vect_opt=1) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize("mu,kappa,size", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))]) def test_vonmises_ordered(self, mu, kappa, size): @@ -503,7 +509,7 @@ def test_vonmises_ordered(self, mu, kappa, size): initval=initval, transform=tr.Chain([tr.circular, tr.ordered]), ) - self.check_vectortransform_elementwise_logp(model, vect_opt=0) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize( "lower,upper,size,transform", @@ -522,7 +528,7 @@ def test_uniform_other(self, lower, upper, size, transform): initval=initval, transform=transform, ) - self.check_vectortransform_elementwise_logp(model, vect_opt=1) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize( "mu,cov,size,shape", @@ -536,7 +542,7 @@ def test_mvnormal_ordered(self, mu, cov, size, shape): model = self.build_model( pm.MvNormal, {"mu": mu, "cov": cov}, size=size, initval=initval, transform=tr.ordered ) - self.check_vectortransform_elementwise_logp(model, vect_opt=1) + self.check_vectortransform_elementwise_logp(model) def test_triangular_transform():