Skip to content

Histogram trick #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jun 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions conda-envs/environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ dependencies:
- pytest-cov>=2.5
- pytest>=3.0
- python=3.8
- dask
- xhistogram
- pip:
- "git+https://github.com/pymc-devs/pymc.git@main"
2 changes: 2 additions & 0 deletions conda-envs/environment-test-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ dependencies:
- pytest-cov>=2.5
- pytest>=3.0
- python=3.9
- dask
- xhistogram
- pip:
- "git+https://github.com/pymc-devs/pymc.git@main"
2 changes: 2 additions & 0 deletions conda-envs/windows-environment-test-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ dependencies:
- pytest-cov>=2.5
- pytest>=3.0
- python=3.7
- dask
- xhistogram
- pip:
- "git+https://github.com/pymc-devs/pymc.git@main"
1 change: 1 addition & 0 deletions pymc_experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@


from pymc_experimental.bart import *
from pymc_experimental import distributions
2 changes: 2 additions & 0 deletions pymc_experimental/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .histogram_utils import histogram_approximation
from . import histogram_utils
128 changes: 128 additions & 0 deletions pymc_experimental/distributions/histogram_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import numpy as np
from numpy.typing import ArrayLike
from typing import Dict
import pymc as pm
import xhistogram.core

# dask is an optional dependency: when it cannot be imported, ``dask`` is bound
# to None so the rest of the module can guard dask-specific code with ``if dask``.
try:
    import dask.dataframe
    import dask.array
except ImportError:
    dask = None


# Names exported by ``from <this module> import *``.
__all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]


def quantile_histogram(
    data: ArrayLike, n_quantiles=1000, zero_inflation=False
) -> Dict[str, ArrayLike]:
    """Bin ``data`` along axis 0, using pooled empirical quantiles as bin edges.

    Parameters
    ----------
    data : ArrayLike
        Observations. A dask Series/DataFrame is converted to a dask array first.
    n_quantiles : int
        Number of quantile edges computed from the pooled data, which yields
        ``n_quantiles - 1`` bins.
    zero_inflation : bool
        When True, exact zeros are counted separately and prepended as an extra
        leading bin with ``lower == upper == 0``; the quantile edges are then
        computed from the strictly positive values only.

    Returns
    -------
    dict
        ``lower``, ``upper`` and ``mid`` hold the bin edges and midpoints with
        trailing singleton axes added to match ``count``'s number of dimensions;
        ``count`` holds the per-bin counts with the bin axis first.
    """
    if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
        data = data.to_dask_array(lengths=True)

    if not zero_inflation:
        binned = data
        pooled = data.flatten()
    else:
        n_zeros = (data == 0).sum(0)
        # Zeros are masked out of the histogram input; they get their own bin below.
        binned = np.ma.masked_where(data == 0, data)
        pooled = data[data > 0]

    edges = np.percentile(pooled, np.linspace(0, 100, n_quantiles))
    if dask:
        # The edges must be concrete values before they can be used as bins.
        (edges,) = dask.compute(edges)

    hist, _ = xhistogram.core.histogram(binned, bins=[edges], axis=0)
    # Move the last axis of the histogram output to the front.
    last_axis = hist.ndim - 1
    hist = hist.transpose(last_axis, *range(last_axis))

    lo, hi = edges[:-1], edges[1:]
    if zero_inflation:
        hist = np.concatenate([n_zeros[None], hist])
        lo = np.concatenate([[0], lo])
        hi = np.concatenate([[0], hi])

    # Pad the edge arrays with singleton axes so they broadcast against the counts.
    padding = (1,) * (hist.ndim - 1)
    lo = lo.reshape(lo.shape + padding)
    hi = hi.reshape(hi.shape + padding)

    return {
        "lower": lo,
        "upper": hi,
        "mid": (lo + hi) / 2,
        "count": hist,
    }


def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
    """Count occurrences of each distinct value of ``data`` along axis 0.

    Parameters
    ----------
    data : ArrayLike
        Discrete observations. A dask Series/DataFrame is converted to a dask
        array first.
    min_count : int, optional
        When given, distinct values whose total number of occurrences (over the
        whole of ``data``) is below ``min_count`` are dropped from the result.

    Returns
    -------
    dict
        ``mid`` holds the retained distinct values with trailing singleton axes
        added to match ``count``'s number of dimensions; ``count`` holds the
        number of occurrences of each retained value, with the value axis first.
    """
    if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
        data = data.to_dask_array(lengths=True)
    values, totals = np.unique(data, return_counts=True)
    # Build one bin per distinct value *before* applying ``min_count``. If the
    # edges were built from the retained values only, observations of a dropped
    # value lying between two retained ones would be counted in a neighbouring
    # retained bin instead of being discarded.
    bins = np.concatenate([values, [values.max() + 1]])
    if dask:
        values, totals, bins = dask.compute(values, totals, bins)
    count, _ = xhistogram.core.histogram(data, bins=[bins], axis=0)
    # Move the last axis of the histogram output to the front.
    count = count.transpose(count.ndim - 1, *range(count.ndim - 1))
    if min_count is not None:
        keep = totals >= min_count
        values = values[keep]
        count = count[keep]
    # Pad with singleton axes so ``mid`` broadcasts against ``count``.
    mid = values.reshape(values.shape + (1,) * (count.ndim - 1))
    return dict(mid=mid, count=count)


def histogram_approximation(name, dist, *, observed: ArrayLike, **h_kwargs):
    """Approximate a distribution with a histogram potential.

    Parameters
    ----------
    name : str
        Name for the Potential
    dist : aesara.tensor.var.TensorVariable
        The output of pm.Distribution.dist()
    observed : ArrayLike
        Observed value to construct a histogram. Histogram is computed over 0th axis. Dask is supported.
        Inputs without a ``dtype`` attribute (e.g. lists) are converted with ``np.asarray``.
    **h_kwargs
        Forwarded to the histogram builder. Integer data goes to
        ``discrete_histogram`` (accepts ``min_count``); any other dtype goes to
        ``quantile_histogram`` (accepts ``n_quantiles`` and ``zero_inflation``).

    Returns
    -------
    aesara.tensor.var.TensorVariable
        Potential

    Examples
    --------
    Discrete variables are reduced to unique repetitions (up to min_count)
    >>> import pymc as pm
    >>> import pymc_experimental as pmx
    >>> production = np.random.poisson([1, 2, 5], size=(1000, 3))
    >>> with pm.Model(coords=dict(plant=range(3))):
    ...     lam = pm.Exponential("lam", 1.0, dims="plant")
    ...     pot = pmx.distributions.histogram_approximation(
    ...         "histogram_potential", pm.Poisson.dist(lam), observed=production, min_count=2
    ...     )

    Continuous variables are discretized into n_quantiles
    >>> measurements = np.random.normal([1, 2, 3], [0.1, 0.4, 0.2], size=(10000, 3))
    >>> with pm.Model(coords=dict(tests=range(3))):
    ...     m = pm.Normal("m", dims="tests")
    ...     s = pm.LogNormal("s", dims="tests")
    ...     pot = pmx.distributions.histogram_approximation(
    ...         "histogram_potential", pm.Normal.dist(m, s),
    ...         observed=measurements, n_quantiles=50
    ...     )

    For special cases like Zero Inflation in Continuous variables there is a flag.
    The flag adds a separate bin for zeros
    >>> measurements = abs(measurements)
    >>> measurements[100:] = 0
    >>> with pm.Model(coords=dict(tests=range(3))):
    ...     m = pm.Normal("m", dims="tests")
    ...     s = pm.LogNormal("s", dims="tests")
    ...     pot = pmx.distributions.histogram_approximation(
    ...         "histogram_potential", pm.Normal.dist(m, s),
    ...         observed=measurements, n_quantiles=50, zero_inflation=True
    ...     )
    """
    if dask and isinstance(observed, (dask.dataframe.Series, dask.dataframe.DataFrame)):
        observed = observed.to_dask_array(lengths=True)
    elif not hasattr(observed, "dtype"):
        # Plain sequences have no ``dtype``; coerce them so the dispatch below works.
        observed = np.asarray(observed)
    if np.issubdtype(observed.dtype, np.integer):
        histogram = discrete_histogram(observed, **h_kwargs)
    else:
        histogram = quantile_histogram(observed, **h_kwargs)
    if dask is not None:
        # Materialise any lazy dask arrays held in the histogram dict.
        (histogram,) = dask.compute(histogram)
    # Each bin contributes logp(midpoint) weighted by the number of observations in it.
    return pm.Potential(name, pm.logp(dist, histogram["mid"]) * histogram["count"])
93 changes: 93 additions & 0 deletions pymc_experimental/tests/test_histogram_approximation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pymc_experimental as pmx
import pymc as pm
import numpy as np
import pytest


@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)
@pytest.mark.parametrize("zero_inflation", [True, False], ids="ZI={}".format)
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_init_cont(use_dask, zero_inflation, ndims):
    """Check dtypes, shapes and total counts returned by ``quantile_histogram``."""
    data = np.random.randn(*(10000, *(2,) * (ndims - 1)))
    if zero_inflation:
        data = abs(data)
        data[:100] = 0
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)
    histogram = pmx.distributions.histogram_utils.quantile_histogram(
        data, n_quantiles=100, zero_inflation=zero_inflation
    )
    if use_dask:
        (histogram,) = dask.compute(histogram)
    assert isinstance(histogram, dict)
    assert isinstance(histogram["mid"], np.ndarray)
    assert np.issubdtype(histogram["mid"].dtype, np.floating)
    # 100 quantile edges give 99 bins, plus one extra bin for zeros when requested.
    size = 99 + zero_inflation
    assert histogram["mid"].shape == (size,) + (1,) * len(data.shape[1:])
    assert histogram["lower"].shape == (size,) + (1,) * len(data.shape[1:])
    assert histogram["upper"].shape == (size,) + (1,) * len(data.shape[1:])
    assert histogram["count"].shape == (size,) + data.shape[1:]
    assert (histogram["count"].sum(0) == 10000).all()
    if zero_inflation:
        # The first 100 rows were zeroed, so the zero bin must hold 100 per column.
        assert (histogram["count"][0] == 100).all()


@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)
@pytest.mark.parametrize("min_count", [None, 5], ids="min_count={}".format)
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_init_discrete(use_dask, min_count, ndims):
    """Check dtypes and shapes returned by ``discrete_histogram``."""
    data = np.random.randint(0, 100, size=(10000,) + (2,) * (ndims - 1))
    # Reference unique values and counts, taken before any dask conversion.
    uniques, occurrences = np.unique(data, return_counts=True)
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)

    histogram = pmx.distributions.histogram_utils.discrete_histogram(data, min_count=min_count)
    if use_dask:
        (histogram,) = dask.compute(histogram)

    assert isinstance(histogram, dict)
    mid = histogram["mid"]
    assert isinstance(mid, np.ndarray)
    assert np.issubdtype(mid.dtype, np.integer)

    if min_count is None:
        size = len(uniques)
    else:
        size = int((occurrences >= min_count).sum())
    trailing = data.shape[1:]
    assert mid.shape == (size,) + (1,) * len(trailing)
    assert histogram["count"].shape == (size,) + trailing
    if not min_count:
        assert (histogram["count"].sum(0) == 10000).all()


@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_approx_cont(use_dask, ndims):
    """Smoke test: a model with a continuous histogram potential can be sampled."""
    shape = (10000, *(2,) * (ndims - 1))
    data = np.random.randn(*shape)
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)
    n_scales = 2 if ndims > 1 else 1
    with pm.Model():
        mean = pm.Normal("m")
        scale = pm.HalfNormal("s", size=n_scales)
        pmx.distributions.histogram_approximation(
            "histogram_potential",
            pm.Normal.dist(mean, scale),
            observed=data,
            n_quantiles=1000,
        )
        # Only a handful of draws: the test checks that sampling runs at all.
        pm.sample(10, tune=0)


@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_approx_discrete(use_dask, ndims):
    """Smoke test: a model with a discrete histogram potential can be sampled."""
    shape = (10000, *(2,) * (ndims - 1))
    data = np.random.randint(0, 100, size=shape)
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)
    n_rates = 2 if ndims > 1 else 1
    with pm.Model():
        rate = pm.Exponential("s", 1.0, size=n_rates)
        pmx.distributions.histogram_approximation(
            "histogram_potential",
            pm.Poisson.dist(rate),
            observed=data,
            min_count=10,
        )
        # Only a handful of draws: the test checks that sampling runs at all.
        pm.sample(10, tune=0)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[tool.pytest.ini_options]
minversion = "6.0"
xfail_strict=true
addopts = "--doctest-modules pymc_experimental"

[tool.black]
line-length = 100
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
dask[all]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pymc>=4.0.0b6
xhistogram
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ def get_version():
classifiers=classifiers,
python_requires=">=3.8",
install_requires=install_reqs,
# setuptools' keyword is ``extras_require``; the misspelt ``extras_requires`` is not
# recognised, so the ``dask`` extra would never be registered.
extras_require=dict(dask=["dask[all]"]),
)