diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 0f633ac2..8bb1afe9 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -30,8 +30,9 @@ Distributions GenExtreme GeneralizedPoisson - histogram_utils.histogram_approximation DiscreteMarkovChain + R2D2M2CP + histogram_approximation Utils diff --git a/pymc_experimental/distributions/__init__.py b/pymc_experimental/distributions/__init__.py index 23c35ccd..468c4cc0 100644 --- a/pymc_experimental/distributions/__init__.py +++ b/pymc_experimental/distributions/__init__.py @@ -19,10 +19,14 @@ from pymc_experimental.distributions.continuous import GenExtreme from pymc_experimental.distributions.discrete import GeneralizedPoisson +from pymc_experimental.distributions.histogram_utils import histogram_approximation +from pymc_experimental.distributions.multivariate import R2D2M2CP from pymc_experimental.distributions.timeseries import DiscreteMarkovChain __all__ = [ "DiscreteMarkovChain", "GeneralizedPoisson", "GenExtreme", + "R2D2M2CP", + "histogram_approximation", ] diff --git a/pymc_experimental/distributions/histogram_utils.py b/pymc_experimental/distributions/histogram_utils.py index d83813df..a20fccab 100644 --- a/pymc_experimental/distributions/histogram_utils.py +++ b/pymc_experimental/distributions/histogram_utils.py @@ -19,27 +19,21 @@ import pymc as pm from numpy.typing import ArrayLike -try: - import dask.array - import dask.dataframe -except ImportError: - dask = None - -try: - import xhistogram.core -except ImportError: - xhistogram = None - - __all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"] def quantile_histogram( data: ArrayLike, n_quantiles=1000, zero_inflation=False ) -> Dict[str, ArrayLike]: - if xhistogram is None: - raise RuntimeError("quantile_histogram requires xhistogram package") - + try: + import xhistogram.core + except ImportError as e: + raise RuntimeError("quantile_histogram requires xhistogram package") from e + try: + import dask.array + import dask.dataframe + except ImportError: + dask = None if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)): data = data.to_dask_array(lengths=True) if zero_inflation: @@ -74,8 +68,15 @@ def quantile_histogram( def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]: - if xhistogram is None: - raise RuntimeError("discrete_histogram requires xhistogram package") + try: + import xhistogram.core + except ImportError as e: + raise RuntimeError("discrete_histogram requires xhistogram package") from e + try: + import dask.array + import dask.dataframe + except ImportError: + dask = None if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)): data = data.to_dask_array(lengths=True) @@ -147,6 +148,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs): ... observed=measurements, n_quantiles=50, zero_inflation=True ... 
) """ + try: + import dask.array + import dask.dataframe + except ImportError: + dask = None if dask and isinstance(observed, (dask.dataframe.Series, dask.dataframe.DataFrame)): observed = observed.to_dask_array(lengths=True) if np.issubdtype(observed.dtype, np.integer): diff --git a/pymc_experimental/distributions/multivariate/__init__.py b/pymc_experimental/distributions/multivariate/__init__.py new file mode 100644 index 00000000..64a79b24 --- /dev/null +++ b/pymc_experimental/distributions/multivariate/__init__.py @@ -0,0 +1 @@ +from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py new file mode 100644 index 00000000..fbdb01af --- /dev/null +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -0,0 +1,298 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Sequence, Union + +import pymc as pm +import pytensor.tensor as pt + +__all__ = ["R2D2M2CP"] + + +def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable): + pi = pt.erfinv(2 * psi - 1) + f = (1 / (2 * pi**2 + 1)) ** 0.5 + sigma = explained_var**0.5 * f + mu = sigma * pi * 2**0.5 + return mu, sigma + + +def _R2D2M2CP_beta( + name: str, + output_sigma: pt.TensorVariable, + input_sigma: pt.TensorVariable, + r2: pt.TensorVariable, + phi: pt.TensorVariable, + psi: pt.TensorVariable, + *, + dims: Union[str, Sequence[str]], + centered=False, +): + """R2D2M2CP beta prior. + + Parameters + ---------- + name: str + Name for the distribution + output_sigma: tensor + standard deviation of the outcome + input_sigma: tensor + standard deviation of the explanatory variables + r2: tensor + expected R2 for the linear regression + phi: tensor + variance weights that sums up to 1 + psi: tensor + probability of a coefficients to be positive + """ + tau2 = r2 / (1 - r2) + explained_variance = phi * pt.expand_dims(tau2 * output_sigma**2, -1) + mu_param, std_param = _psivar2musigma(psi, explained_variance) + if not centered: + with pm.Model(name): + raw = pm.Normal("raw", dims=dims) + beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims) + else: + beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims) + return beta + + +def R2D2M2CP( + name, + output_sigma, + input_sigma, + *, + dims, + r2, + variables_importance=None, + variance_explained=None, + r2_std=None, + positive_probs=0.5, + positive_probs_std=None, + centered=False, +): + """R2D2M2CP Prior. 
+
+    Parameters
+    ----------
+    name : str
+        Name for the distribution
+    output_sigma : tensor
+        Output standard deviation
+    input_sigma : tensor
+        Input standard deviation
+    dims : Union[str, Sequence[str]]
+        Dims for the distribution
+    r2 : tensor
+        :math:`R^2` estimate
+    variables_importance : tensor, optional
+        Optional estimate of variable importances; positive values, by default None
+    variance_explained : tensor, optional
+        Alternative importance estimate: a point estimate of the share of variance
+        explained per variable; should sum to one, by default None
+    r2_std : tensor, optional
+        Optional uncertainty over :math:`R^2`, by default None
+    positive_probs : tensor, optional
+        Optional probability that a variable's contribution is positive, by default 0.5
+    positive_probs_std : tensor, optional
+        Optional uncertainty over the effect direction probability, by default None
+    centered : bool, optional
+        Centered or non-centered parametrization of the distribution, by default
+        non-centered. It is advised to check both.
+
+    Returns
+    -------
+    residual_sigma, coefficients
+        The output variance (sigma squared) is split into residual variance and
+        explained variance.
+
+    Raises
+    ------
+    TypeError
+        If an invalid combination of arguments is passed.
+
+    Notes
+    -----
+    The R2D2M2CP prior is a modification of the R2D2M2 prior.
+
+    - ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
+    - R2D2M2``(CP)`` (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
+
+    Examples
+    --------
+    Here the arguments are explained with a synthetic example.
+
+    .. warning::
+
+        To use the prior in a linear regression
+
+        - make sure :math:`X` is centered around zero
+        - the intercept represents the prior predictive mean when :math:`X` is centered
+        - setting named dims is required
+
+    .. code-block:: python
+
+        import pymc_experimental as pmx
+        import pymc as pm
+        import numpy as np
+        X = np.random.randn(10, 3)
+        b = np.random.randn(3)
+        y = X @ b + np.random.randn(10) * 0.04 + 5
+        with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
+            eps, beta = pmx.distributions.R2D2M2CP(
+                "beta",
+                y.std(),
+                X.std(0),
+                dims="variables",
+                # NOTE: global shrinkage
+                r2=0.8,
+                # NOTE: if you are unsure about r2
+                r2_std=0.2,
+                # NOTE: if you know which direction a variable should go;
+                # if you do not know, leave it at 0.5
+                positive_probs=[0.8, 0.5, 0.1],
+                # NOTE: uncertainty over the effect direction;
+                # if you put 0.5 above, 0.1 is a reasonable sigma,
+                # but other sigmas work fine too
+                positive_probs_std=[0.3, 0.1, 0.2],
+                # NOTE: variable importances are relative to each other,
+                # but larger numbers give a variable more weight in the relation
+                # use
+                # * 1-10 for small confidence
+                # * 10-30 for moderate confidence
+                # * 30+ for high confidence
+                # EXAMPLE:
+                # "a" is likely to be useful
+                # "b" - no idea if it is useful
+                # "c" is a must-have in the relation
+                variables_importance=[10, 1, 34],
+                # NOTE: try both
+                centered=True
+            )
+            # the intercept prior should be centered around the prior predictive mean
+            intercept = y.mean()
+            # regressors should be centered around zero
+            Xc = X - X.mean(0)
+            obs = pm.Normal("obs", intercept + Xc @ beta, eps, observed=y)
+
+    Specific argument combinations yield special cases.
+
+    Here the prior distribution of beta is ``Normal(0, y.std() * r2 ** .5)``:
+
+    .. code-block:: python
+
+        with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
+            eps, beta = pmx.distributions.R2D2M2CP(
+                "beta",
+                y.std(),
+                X.std(0),
+                dims="variables",
+                # NOTE: global shrinkage
+                r2=0.8,
+                # NOTE: if you are unsure about r2
+                r2_std=0.2,
+                centered=False
+            )
+            # the intercept prior should be centered around the prior predictive mean
+            intercept = y.mean()
+            # regressors should be centered around zero
+            Xc = X - X.mean(0)
+            obs = pm.Normal("obs", intercept + Xc @ beta, eps, observed=y)
+
+    It is fine to leave some of the ``_std`` arguments unspecified.
+    You can also specify only ``positive_probs``; all the variables are then
+    assumed to explain the same amount of variance (equal importance).
+
+    .. code-block:: python
+
+        with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
+            eps, beta = pmx.distributions.R2D2M2CP(
+                "beta",
+                y.std(),
+                X.std(0),
+                dims="variables",
+                # NOTE: global shrinkage
+                r2=0.8,
+                # NOTE: if you are unsure about r2
+                r2_std=0.2,
+                # NOTE: if you know which direction a variable should go;
+                # if you do not know, leave it at 0.5
+                positive_probs=[0.8, 0.5, 0.1],
+                # NOTE: try both
+                centered=True
+            )
+            intercept = y.mean()
+            Xc = X - X.mean(0)
+            obs = pm.Normal("obs", intercept + Xc @ beta, eps, observed=y)
+
+    References
+    ----------
+    To cite the R2D2M2CP implementation, you can use the following BibTeX entry:
+
+    .. code-block:: bibtex
+
+        @misc{pymc-experimental-r2d2m2cp,
+            title = {pymc-devs/pymc-experimental: {P}ull {R}equest 137, {R2D2M2CP}},
+            url = {https://github.com/pymc-devs/pymc-experimental/pull/137},
+            author = {Max Kochurov},
+            howpublished = {GitHub},
+            year = {2023}
+        }
+    """
+    if not isinstance(dims, (list, tuple)):
+        dims = (dims,)
+    *broadcast_dims, dim = dims
+    input_sigma = pt.as_tensor(input_sigma)
+    output_sigma = pt.as_tensor(output_sigma)
+    with pm.Model(name) as model:
+        if variables_importance is not None:
+            if variance_explained is not None:
+                raise TypeError("Can't use variable importance with variance explained")
+            if len(model.coords[dim]) <= 1:
+                raise TypeError("Can't use variable importance with less than two variables")
+            phi = pm.Dirichlet(
+                "phi", pt.as_tensor(variables_importance), dims=broadcast_dims + [dim]
+            )
+        elif variance_explained is not None:
+            if len(model.coords[dim]) <= 1:
+                raise TypeError("Can't use variance explained with less than two variables")
+            phi = pt.as_tensor(variance_explained)
+        else:
+            phi = 1 / len(model.coords[dim])
+        if r2_std is not None:
+            r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
+        if positive_probs_std is not None:
+            psi = pm.Beta(
+                "psi",
+                mu=pt.as_tensor(positive_probs),
+                sigma=pt.as_tensor(positive_probs_std),
+                dims=broadcast_dims + [dim],
+            )
+        else:
+            psi = pt.as_tensor(positive_probs)
+    beta = _R2D2M2CP_beta(
+        name,
+        output_sigma,
+        input_sigma,
+        r2,
+        phi,
+        psi,
+        dims=broadcast_dims + [dim],
+        centered=centered,
+    )
+    resid_sigma = (1 - r2) ** 0.5 * output_sigma
+    return resid_sigma, beta
diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py
new file mode 100644
index 00000000..ee28c3e3
--- /dev/null
+++ b/pymc_experimental/tests/distributions/test_multivariate.py
@@ -0,0 +1,165 @@
+import numpy as np
+import pymc as pm
+import pytest
+
+import pymc_experimental as pmx
+
+
+class TestR2D2M2CP:
+    @pytest.fixture(autouse=True)
+    def model(self):
+        # every test method runs inside a model
+        with pm.Model() as model:
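+            # yield while the context manager is open so each test
+            # method in the class executes with this model active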
+            yield model
+
+    @pytest.fixture(params=[True, False], ids=["centered", "non-centered"])
+    def centered(self, request):
+        return request.param
+
+    @pytest.fixture(params=[["a"], ["a", "b"], ["one"]])
+    def dims(self, model: pm.Model, request):
+        for i, c in enumerate(request.param):
+            if c == "one":
+                model.add_coord(c, range(1))
+            else:
+                model.add_coord(c, range((i + 2) ** 2))
+        return request.param
+
+    @pytest.fixture
+    def input_shape(self, dims, model):
+        return [int(model.dim_lengths[d].eval()) for d in dims]
+
+    @pytest.fixture
+    def output_shape(self, dims, model):
+        *hierarchy, _ = dims
+        return [int(model.dim_lengths[d].eval()) for d in hierarchy]
+
+    @pytest.fixture
+    def input_std(self, input_shape):
+        return np.ones(input_shape)
+
+    @pytest.fixture
+    def output_std(self, output_shape):
+        return np.ones(output_shape)
+
+    @pytest.fixture
+    def r2(self):
+        return 0.8
+
+    @pytest.fixture(params=[None, 0.1], ids=["no-r2-std", "r2-std"])
+    def r2_std(self, request):
+        return request.param
+
+    @pytest.fixture(params=[True, False], ids=["probs", "no-probs"])
+    def positive_probs(self, input_std, request):
+        if request.param:
+            return np.full_like(input_std, 0.5)
+        else:
+            return 0.5
+
+    @pytest.fixture(params=[True, False], ids=["probs-std", "no-probs-std"])
+    def positive_probs_std(self, positive_probs, request):
+        if request.param:
+            return np.full_like(positive_probs, 0.1)
+        else:
+            return None
+
+    @pytest.fixture(params=[None, "importance", "explained"])
+    def phi_args(self, request, input_shape):
+        if input_shape[-1] < 2 and request.param is not None:
+            pytest.skip("not compatible")
+        elif request.param is None:
+            return {}
+        elif request.param == "importance":
+            return {"variables_importance": np.full(input_shape, 2)}
+        else:
+            val = np.full(input_shape, 2)
+            return {"variance_explained": val / val.sum(-1, keepdims=True)}
+
+    def test_init(
+        self,
+        dims,
+        centered,
+        input_std,
+        output_std,
+        r2,
+        r2_std,
+        positive_probs,
+        positive_probs_std,
+        phi_args,
+        model: pm.Model,
+    ):
+        eps, beta = pmx.distributions.R2D2M2CP(
+            "beta",
+            output_std,
+            input_std,
+            dims=dims,
+            r2=r2,
+            r2_std=r2_std,
+            centered=centered,
+            positive_probs_std=positive_probs_std,
+            positive_probs=positive_probs,
+            **phi_args
+        )
+        assert eps.eval().shape == output_std.shape
+        assert beta.eval().shape == input_std.shape
+        # r2 rv is only created if r2_std is not None
+        assert ("beta::r2" in model.named_vars) == (r2_std is not None), set(model.named_vars)
+        # phi is only created if variables_importance is passed and there is more than one variable
+        assert ("beta::phi" in model.named_vars) == ("variables_importance" in phi_args), set(
+            model.named_vars
+        )
+        assert ("beta::psi" in model.named_vars) == (positive_probs_std is not None), set(
+            model.named_vars
+        )
+
+    def test_failing_importance(self, dims, input_shape, output_std, input_std):
+        if input_shape[-1] < 2:
+            with pytest.raises(TypeError, match="less than two variables"):
+                pmx.distributions.R2D2M2CP(
+                    "beta",
+                    output_std,
+                    input_std,
+                    dims=dims,
+                    r2=0.8,
+                    variables_importance=abs(input_std),
+                )
+        else:
+            pmx.distributions.R2D2M2CP(
+                "beta",
+                output_std,
+                input_std,
+                dims=dims,
+                r2=0.8,
+                variables_importance=abs(input_std),
+            )
+
+    def test_failing_variance_explained(self, dims, input_shape, output_std, input_std):
+        if input_shape[-1] < 2:
+            with pytest.raises(TypeError, match="less than two variables"):
+                pmx.distributions.R2D2M2CP(
+                    "beta",
+                    output_std,
+                    input_std,
+                    dims=dims,
+                    r2=0.8,
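+                    # with a single variable, the "less than two variables"
+                    # check fires before variance_explained is inspected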
variance_explained=abs(input_std), + ) + else: + pmx.distributions.R2D2M2CP( + "beta", output_std, input_std, dims=dims, r2=0.8, variance_explained=abs(input_std) + ) + + def test_failing_mutual_exclusive(self, model: pm.Model): + with pytest.raises(TypeError, match="variable importance with variance explained"): + with model: + model.add_coord("a", range(2)) + pmx.distributions.R2D2M2CP( + "beta", + 1, + [1, 1], + dims="a", + r2=0.8, + variance_explained=[0.5, 0.5], + variables_importance=[1, 1], + ) diff --git a/pyproject.toml b/pyproject.toml index 796ef376..ebc76bc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,12 @@ [tool.pytest.ini_options] minversion = "6.0" -xfail_strict=true +xfail_strict = true +addopts = [ + "-v", + "--doctest-modules", + "--ignore=pymc_experimental/model_builder.py" +] + [tool.black] line-length = 100
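A minimal usage sketch of the newly re-exported ``histogram_approximation``, based on
the ``(name, dist, *, observed, **h_kwargs)`` signature above. This assumes the
optional ``xhistogram`` dependency is installed (the imports above are resolved lazily
at call time), and the ``measurements`` array is synthetic, invented purely for
illustration.

.. code-block:: python

    import numpy as np
    import pymc as pm
    import pymc_experimental as pmx

    # synthetic stand-in for a large dataset (illustration only)
    measurements = np.random.default_rng(0).normal(loc=1.0, scale=2.0, size=100_000)

    with pm.Model():
        mu = pm.Normal("mu", 0, 10)
        sigma = pm.HalfNormal("sigma", 5)
        # float data take the quantile-histogram path: the likelihood is
        # evaluated on a 50-bin histogram instead of on all 100_000 points
        pot = pmx.distributions.histogram_approximation(
            "lik", pm.Normal.dist(mu, sigma), observed=measurements, n_quantiles=50
        )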