diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml index 33f937078..0ecbe95d2 100644 --- a/conda-envs/environment-test-py38.yml +++ b/conda-envs/environment-test-py38.yml @@ -7,5 +7,7 @@ dependencies: - pytest-cov>=2.5 - pytest>=3.0 - python=3.8 +- dask +- xhistogram - pip: - "git+https://github.com/pymc-devs/pymc.git@main" diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml index acdffaaaa..570af3855 100644 --- a/conda-envs/environment-test-py39.yml +++ b/conda-envs/environment-test-py39.yml @@ -7,5 +7,7 @@ dependencies: - pytest-cov>=2.5 - pytest>=3.0 - python=3.9 +- dask +- xhistogram - pip: - "git+https://github.com/pymc-devs/pymc.git@main" diff --git a/conda-envs/windows-environment-test-py37.yml b/conda-envs/windows-environment-test-py37.yml index cb511a5df..cde505bfe 100644 --- a/conda-envs/windows-environment-test-py37.yml +++ b/conda-envs/windows-environment-test-py37.yml @@ -7,5 +7,7 @@ dependencies: - pytest-cov>=2.5 - pytest>=3.0 - python=3.7 +- dask +- xhistogram - pip: - "git+https://github.com/pymc-devs/pymc.git@main" diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index 516326ee1..ee408ac71 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -12,3 +12,4 @@ from pymc_experimental.bart import * +from pymc_experimental import distributions diff --git a/pymc_experimental/distributions/__init__.py b/pymc_experimental/distributions/__init__.py index e69de29bb..a477888f9 100644 --- a/pymc_experimental/distributions/__init__.py +++ b/pymc_experimental/distributions/__init__.py @@ -0,0 +1,2 @@ +from .histogram_utils import histogram_approximation +from . 
import histogram_utils diff --git a/pymc_experimental/distributions/histogram_utils.py b/pymc_experimental/distributions/histogram_utils.py new file mode 100644 index 000000000..a79ac88bc --- /dev/null +++ b/pymc_experimental/distributions/histogram_utils.py @@ -0,0 +1,128 @@ +import numpy as np +from numpy.typing import ArrayLike +from typing import Dict +import pymc as pm +import xhistogram.core + +try: + import dask.dataframe + import dask.array +except ImportError: + dask = None + + +__all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"] + + +def quantile_histogram( + data: ArrayLike, n_quantiles=1000, zero_inflation=False +) -> Dict[str, ArrayLike]: + if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)): + data = data.to_dask_array(lengths=True) + if zero_inflation: + zeros = (data == 0).sum(0) + mdata = np.ma.masked_where(data == 0, data) + qdata = data[data > 0] + else: + mdata = data + qdata = data.flatten() + quantiles = np.percentile(qdata, np.linspace(0, 100, n_quantiles)) + if dask: + (quantiles,) = dask.compute(quantiles) + count, _ = xhistogram.core.histogram(mdata, bins=[quantiles], axis=0) + count = count.transpose(count.ndim - 1, *range(count.ndim - 1)) + lower = quantiles[:-1] + upper = quantiles[1:] + + if zero_inflation: + count = np.concatenate([zeros[None], count]) + lower = np.concatenate([[0], lower]) + upper = np.concatenate([[0], upper]) + lower = lower.reshape(lower.shape + (1,) * (count.ndim - 1)) + upper = upper.reshape(upper.shape + (1,) * (count.ndim - 1)) + + result = dict( + lower=lower, + upper=upper, + mid=(lower + upper) / 2, + count=count, + ) + return result + + +def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]: + if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)): + data = data.to_dask_array(lengths=True) + mid, count_uniq = np.unique(data, return_counts=True) + if min_count is not None: + mid = mid[count_uniq >= 
min_count] + count_uniq = count_uniq[count_uniq >= min_count] + bins = np.concatenate([mid, [mid.max() + 1]]) + if dask: + mid, bins = dask.compute(mid, bins) + count, _ = xhistogram.core.histogram(data, bins=[bins], axis=0) + count = count.transpose(count.ndim - 1, *range(count.ndim - 1)) + mid = mid.reshape(mid.shape + (1,) * (count.ndim - 1)) + return dict(mid=mid, count=count) + + +def histogram_approximation(name, dist, *, observed: ArrayLike, **h_kwargs): + """Approximate a distribution with a histogram potential. + + Parameters + ---------- + name : str + Name for the Potential + dist : aesara.tensor.var.TensorVariable + The output of pm.Distribution.dist() + observed : ArrayLike + Observed value to construct a histogram. Histogram is computed over 0th axis. Dask is supported. + + Returns + ------- + aesara.tensor.var.TensorVariable + Potential + + Examples + -------- + Discrete variables are reduced to unique repetitions (up to min_count) + >>> import pymc as pm + >>> import pymc_experimental as pmx + >>> production = np.random.poisson([1, 2, 5], size=(1000, 3)) + >>> with pm.Model(coords=dict(plant=range(3))): + ... lam = pm.Exponential("lam", 1.0, dims="plant") + ... pot = pmx.distributions.histogram_approximation( + ... "histogram_potential", pm.Poisson.dist(lam), observed=production, min_count=2 + ... ) + + Continuous variables are discretized into n_quantiles + >>> measurements = np.random.normal([1, 2, 3], [0.1, 0.4, 0.2], size=(10000, 3)) + >>> with pm.Model(coords=dict(tests=range(3))): + ... m = pm.Normal("m", dims="tests") + ... s = pm.LogNormal("s", dims="tests") + ... pot = pmx.distributions.histogram_approximation( + ... "histogram_potential", pm.Normal.dist(m, s), + ... observed=measurements, n_quantiles=50 + ... ) + + For special cases like Zero Inflation in Continuous variables there is a flag. 
+ The flag adds a separate bin for zeros + >>> measurements = abs(measurements) + >>> measurements[100:] = 0 + >>> with pm.Model(coords=dict(tests=range(3))): + ... m = pm.Normal("m", dims="tests") + ... s = pm.LogNormal("s", dims="tests") + ... pot = pmx.distributions.histogram_approximation( + ... "histogram_potential", pm.Normal.dist(m, s), + ... observed=measurements, n_quantiles=50, zero_inflation=True + ... ) + """ + if dask and isinstance(observed, (dask.dataframe.Series, dask.dataframe.DataFrame)): + observed = observed.to_dask_array(lengths=True) + if np.issubdtype(observed.dtype, np.integer): + histogram = discrete_histogram(observed, **h_kwargs) + else: + histogram = quantile_histogram(observed, **h_kwargs) + if dask is not None: + (histogram,) = dask.compute(histogram) + return pm.Potential(name, pm.logp(dist, histogram["mid"]) * histogram["count"]) diff --git a/pymc_experimental/tests/test_histogram_approximation.py b/pymc_experimental/tests/test_histogram_approximation.py new file mode 100644 index 000000000..8684aaf44 --- /dev/null +++ b/pymc_experimental/tests/test_histogram_approximation.py @@ -0,0 +1,93 @@ +import pymc_experimental as pmx +import pymc as pm +import numpy as np +import pytest + + +@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format) +@pytest.mark.parametrize("zero_inflation", [True, False], ids="ZI={}".format) +@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format) +def test_histogram_init_cont(use_dask, zero_inflation, ndims): + data = np.random.randn(*(10000, *(2,) * (ndims - 1))) + if zero_inflation: + data = abs(data) + data[:100] = 0 + if use_dask: + dask = pytest.importorskip("dask") + dask_df = pytest.importorskip("dask.dataframe") + data = dask_df.from_array(data) + histogram = pmx.distributions.histogram_utils.quantile_histogram( + data, n_quantiles=100, zero_inflation=zero_inflation + ) + if use_dask: + (histogram,) = dask.compute(histogram) + assert isinstance(histogram, dict) + assert 
isinstance(histogram["mid"], np.ndarray) + assert np.issubdtype(histogram["mid"].dtype, np.floating) + size = 99 + zero_inflation + assert histogram["mid"].shape == (size,) + (1,) * len(data.shape[1:]) + assert histogram["lower"].shape == (size,) + (1,) * len(data.shape[1:]) + assert histogram["upper"].shape == (size,) + (1,) * len(data.shape[1:]) + assert histogram["count"].shape == (size,) + data.shape[1:] + assert (histogram["count"].sum(0) == 10000).all() + if zero_inflation: + assert (histogram["count"][0] == 100).all() + + +@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format) +@pytest.mark.parametrize("min_count", [None, 5], ids="min_count={}".format) +@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format) +def test_histogram_init_discrete(use_dask, min_count, ndims): + data = np.random.randint(0, 100, size=(10000,) + (2,) * (ndims - 1)) + u, c = np.unique(data, return_counts=True) + if use_dask: + dask = pytest.importorskip("dask") + dask_df = pytest.importorskip("dask.dataframe") + data = dask_df.from_array(data) + histogram = pmx.distributions.histogram_utils.discrete_histogram(data, min_count=min_count) + if use_dask: + (histogram,) = dask.compute(histogram) + assert isinstance(histogram, dict) + assert isinstance(histogram["mid"], np.ndarray) + assert np.issubdtype(histogram["mid"].dtype, np.integer) + if min_count is not None: + size = int((c >= min_count).sum()) + else: + size = len(u) + assert histogram["mid"].shape == (size,) + (1,) * len(data.shape[1:]) + assert histogram["count"].shape == (size,) + data.shape[1:] + if not min_count: + assert (histogram["count"].sum(0) == 10000).all() + + +@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format) +@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format) +def test_histogram_approx_cont(use_dask, ndims): + data = np.random.randn(*(10000, *(2,) * (ndims - 1))) + if use_dask: + dask = pytest.importorskip("dask") + dask_df = 
pytest.importorskip("dask.dataframe") + data = dask_df.from_array(data) + with pm.Model(): + m = pm.Normal("m") + s = pm.HalfNormal("s", size=2 if ndims > 1 else 1) + pot = pmx.distributions.histogram_approximation( + "histogram_potential", pm.Normal.dist(m, s), observed=data, n_quantiles=1000 + ) + trace = pm.sample(10, tune=0) # very fast + + +@pytest.mark.parametrize("use_dask", [True, False]) +@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format) +def test_histogram_approx_discrete(use_dask, ndims): + data = np.random.randint(0, 100, size=(10000, *(2,) * (ndims - 1))) + if use_dask: + dask = pytest.importorskip("dask") + dask_df = pytest.importorskip("dask.dataframe") + data = dask_df.from_array(data) + with pm.Model(): + s = pm.Exponential("s", 1.0, size=2 if ndims > 1 else 1) + pot = pmx.distributions.histogram_approximation( + "histogram_potential", pm.Poisson.dist(s), observed=data, min_count=10 + ) + trace = pm.sample(10, tune=0) # very fast diff --git a/pyproject.toml b/pyproject.toml index aa4e5e2bf..5dbea635d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,7 @@ [tool.pytest.ini_options] minversion = "6.0" xfail_strict=true +addopts = "--doctest-modules pymc_experimental" [tool.black] line-length = 100 diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..6c049906a --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1 @@ +dask[all] diff --git a/requirements.txt b/requirements.txt index 02ee67aba..c93f86d98 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ pymc>=4.0.0b6 +xhistogram diff --git a/setup.py b/setup.py index 545c6f500..140aa8317 100644 --- a/setup.py +++ b/setup.py @@ -82,4 +82,5 @@ def get_version(): classifiers=classifiers, python_requires=">=3.8", install_requires=install_reqs, + extras_require=dict(dask=["dask[all]"]), )