Commit f50600b

Histogram trick (#38)
* add quantile histogram helper function
* rename module to utils
* add dask as dev dependency
* add thistogram for dask
* add dask function
* add histogram approximation
* renames
* renaming
* move files around
* add tests for discrete cases
* add discrete case
* support series and arrays
* add zero inflation for numpy
* fix tests with more generic asserts
* add zero inflation to dask implementation
* arguments for the base function
* simplify implementation
* add docs
* update requirements
* test 2d inputs
* fix bugs
* add docs
* move doctest to workflows
* newline
* adjust pipelines
* add __all__ entry
* update test envs

1 parent efa57dd commit f50600b

11 files changed, +234 −0 lines changed
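
The heart of the change is `histogram_approximation` in the new `histogram_utils` module: rather than evaluating the likelihood once per observation, the observed data are reduced to histogram bins (quantile bins for continuous data, unique values for discrete data), and the log-probability at each bin midpoint is weighted by the bin count inside a `pm.Potential`. Below is a minimal sketch of that idea using plain NumPy bins; it is illustrative only (names like `mu`, `sigma`, and `hist_logp` are made up here), whereas the committed helper builds quantile-based bins with xhistogram and also accepts dask inputs.

import numpy as np
import pymc as pm

# Reduce 10,000 observations to ~100 bins: midpoints and counts stand in for the raw data.
observed = np.random.normal(1.0, 0.5, size=10_000)
counts, edges = np.histogram(observed, bins=100)
mids = (edges[:-1] + edges[1:]) / 2

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    # Weight the logp at each bin midpoint by how many observations fell into that bin.
    pm.Potential("hist_logp", (pm.logp(pm.Normal.dist(mu, sigma), mids) * counts).sum())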

conda-envs/environment-test-py38.yml (+2)

@@ -7,5 +7,7 @@ dependencies:
 - pytest-cov>=2.5
 - pytest>=3.0
 - python=3.8
+- dask
+- xhistogram
 - pip:
   - "git+https://github.com/pymc-devs/pymc.git@main"

conda-envs/environment-test-py39.yml (+2)

@@ -7,5 +7,7 @@ dependencies:
 - pytest-cov>=2.5
 - pytest>=3.0
 - python=3.9
+- dask
+- xhistogram
 - pip:
   - "git+https://github.com/pymc-devs/pymc.git@main"

conda-envs/windows-environment-test-py37.yml (+2)

@@ -7,5 +7,7 @@ dependencies:
 - pytest-cov>=2.5
 - pytest>=3.0
 - python=3.7
+- dask
+- xhistogram
 - pip:
   - "git+https://github.com/pymc-devs/pymc.git@main"

pymc_experimental/__init__.py (+1)

@@ -12,3 +12,4 @@


 from pymc_experimental.bart import *
+from pymc_experimental import distributions

pymc_experimental/distributions/__init__.py (+2)

@@ -0,0 +1,2 @@
+from .histogram_utils import histogram_approximation
+from . import histogram_utils

pymc_experimental/distributions/histogram_utils.py (+128)

@@ -0,0 +1,128 @@
import numpy as np
from numpy.typing import ArrayLike
from typing import Dict
import pymc as pm
import xhistogram.core

try:
    import dask.dataframe
    import dask.array
except ImportError:
    dask = None


__all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]


def quantile_histogram(
    data: ArrayLike, n_quantiles=1000, zero_inflation=False
) -> Dict[str, ArrayLike]:
    if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
        data = data.to_dask_array(lengths=True)
    if zero_inflation:
        zeros = (data == 0).sum(0)
        mdata = np.ma.masked_where(data == 0, data)
        qdata = data[data > 0]
    else:
        mdata = data
        qdata = data.flatten()
    quantiles = np.percentile(qdata, np.linspace(0, 100, n_quantiles))
    if dask:
        (quantiles,) = dask.compute(quantiles)
    count, _ = xhistogram.core.histogram(mdata, bins=[quantiles], axis=0)
    count = count.transpose(count.ndim - 1, *range(count.ndim - 1))
    lower = quantiles[:-1]
    upper = quantiles[1:]

    if zero_inflation:
        count = np.concatenate([zeros[None], count])
        lower = np.concatenate([[0], lower])
        upper = np.concatenate([[0], upper])
    lower = lower.reshape(lower.shape + (1,) * (count.ndim - 1))
    upper = upper.reshape(upper.shape + (1,) * (count.ndim - 1))

    result = dict(
        lower=lower,
        upper=upper,
        mid=(lower + upper) / 2,
        count=count,
    )
    return result


def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
    if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
        data = data.to_dask_array(lengths=True)
    mid, count_uniq = np.unique(data, return_counts=True)
    if min_count is not None:
        mid = mid[count_uniq >= min_count]
        count_uniq = count_uniq[count_uniq >= min_count]
    bins = np.concatenate([mid, [mid.max() + 1]])
    if dask:
        mid, bins = dask.compute(mid, bins)
    count, _ = xhistogram.core.histogram(data, bins=[bins], axis=0)
    count = count.transpose(count.ndim - 1, *range(count.ndim - 1))
    mid = mid.reshape(mid.shape + (1,) * (count.ndim - 1))
    return dict(mid=mid, count=count)


def histogram_approximation(name, dist, *, observed: ArrayLike, **h_kwargs):
    """Approximate a distribution with a histogram potential.

    Parameters
    ----------
    name : str
        Name for the Potential
    dist : aesara.tensor.var.TensorVariable
        The output of pm.Distribution.dist()
    observed : ArrayLike
        Observed value to construct a histogram. The histogram is computed over the 0th axis.
        Dask is supported.

    Returns
    -------
    aesara.tensor.var.TensorVariable
        Potential

    Examples
    --------
    Discrete variables are reduced to unique repetitions (up to min_count)

    >>> import numpy as np
    >>> import pymc as pm
    >>> import pymc_experimental as pmx
    >>> production = np.random.poisson([1, 2, 5], size=(1000, 3))
    >>> with pm.Model(coords=dict(plant=range(3))):
    ...     lam = pm.Exponential("lam", 1.0, dims="plant")
    ...     pot = pmx.distributions.histogram_approximation(
    ...         "histogram_potential", pm.Poisson.dist(lam), observed=production, min_count=2
    ...     )

    Continuous variables are discretized into n_quantiles

    >>> measurements = np.random.normal([1, 2, 3], [0.1, 0.4, 0.2], size=(10000, 3))
    >>> with pm.Model(coords=dict(tests=range(3))):
    ...     m = pm.Normal("m", dims="tests")
    ...     s = pm.LogNormal("s", dims="tests")
    ...     pot = pmx.distributions.histogram_approximation(
    ...         "histogram_potential", pm.Normal.dist(m, s),
    ...         observed=measurements, n_quantiles=50
    ...     )

    For special cases such as zero inflation in continuous variables there is a flag
    that adds a separate bin for zeros

    >>> measurements = abs(measurements)
    >>> measurements[100:] = 0
    >>> with pm.Model(coords=dict(tests=range(3))):
    ...     m = pm.Normal("m", dims="tests")
    ...     s = pm.LogNormal("s", dims="tests")
    ...     pot = pmx.distributions.histogram_approximation(
    ...         "histogram_potential", pm.Normal.dist(m, s),
    ...         observed=measurements, n_quantiles=50, zero_inflation=True
    ...     )
    """
    if dask and isinstance(observed, (dask.dataframe.Series, dask.dataframe.DataFrame)):
        observed = observed.to_dask_array(lengths=True)
    if np.issubdtype(observed.dtype, np.integer):
        histogram = discrete_histogram(observed, **h_kwargs)
    else:
        histogram = quantile_histogram(observed, **h_kwargs)
    if dask is not None:
        (histogram,) = dask.compute(histogram)
    return pm.Potential(name, pm.logp(dist, histogram["mid"]) * histogram["count"])
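
Since all three helpers detect dask Series/DataFrames and convert them with `to_dask_array(lengths=True)`, larger-than-memory observations can be passed straight into `histogram_approximation`. A sketch under the assumption that dask is installed (the data, column name, and partition count below are made up for illustration):

import dask.dataframe as dd
import numpy as np
import pandas as pd
import pymc as pm
import pymc_experimental as pmx

# Hypothetical integer-valued observations, partitioned with dask.
frame = pd.DataFrame({"y": np.random.poisson(3.0, size=100_000)})
observed = dd.from_pandas(frame["y"], npartitions=8)

with pm.Model():
    lam = pm.Exponential("lam", 1.0)
    # Integer dtype dispatches to discrete_histogram, so min_count is a valid kwarg here.
    pmx.distributions.histogram_approximation(
        "histogram_potential", pm.Poisson.dist(lam), observed=observed, min_count=5
    )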

New test file (+93)

@@ -0,0 +1,93 @@
import pymc_experimental as pmx
import pymc as pm
import numpy as np
import pytest


@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)
@pytest.mark.parametrize("zero_inflation", [True, False], ids="ZI={}".format)
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_init_cont(use_dask, zero_inflation, ndims):
    data = np.random.randn(*(10000, *(2,) * (ndims - 1)))
    if zero_inflation:
        data = abs(data)
        data[:100] = 0
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)
    histogram = pmx.distributions.histogram_utils.quantile_histogram(
        data, n_quantiles=100, zero_inflation=zero_inflation
    )
    if use_dask:
        (histogram,) = dask.compute(histogram)
    assert isinstance(histogram, dict)
    assert isinstance(histogram["mid"], np.ndarray)
    assert np.issubdtype(histogram["mid"].dtype, np.floating)
    size = 99 + zero_inflation
    assert histogram["mid"].shape == (size,) + (1,) * len(data.shape[1:])
    assert histogram["lower"].shape == (size,) + (1,) * len(data.shape[1:])
    assert histogram["upper"].shape == (size,) + (1,) * len(data.shape[1:])
    assert histogram["count"].shape == (size,) + data.shape[1:]
    assert (histogram["count"].sum(0) == 10000).all()
    if zero_inflation:
        assert (histogram["count"][0] == 100).all()


@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)
@pytest.mark.parametrize("min_count", [None, 5], ids="min_count={}".format)
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_init_discrete(use_dask, min_count, ndims):
    data = np.random.randint(0, 100, size=(10000,) + (2,) * (ndims - 1))
    u, c = np.unique(data, return_counts=True)
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)
    histogram = pmx.distributions.histogram_utils.discrete_histogram(data, min_count=min_count)
    if use_dask:
        (histogram,) = dask.compute(histogram)
    assert isinstance(histogram, dict)
    assert isinstance(histogram["mid"], np.ndarray)
    assert np.issubdtype(histogram["mid"].dtype, np.integer)
    if min_count is not None:
        size = int((c >= min_count).sum())
    else:
        size = len(u)
    assert histogram["mid"].shape == (size,) + (1,) * len(data.shape[1:])
    assert histogram["count"].shape == (size,) + data.shape[1:]
    if not min_count:
        assert (histogram["count"].sum(0) == 10000).all()


@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_approx_cont(use_dask, ndims):
    data = np.random.randn(*(10000, *(2,) * (ndims - 1)))
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)
    with pm.Model():
        m = pm.Normal("m")
        s = pm.HalfNormal("s", size=2 if ndims > 1 else 1)
        pot = pmx.distributions.histogram_approximation(
            "histogram_potential", pm.Normal.dist(m, s), observed=data, n_quantiles=1000
        )
        trace = pm.sample(10, tune=0)  # very fast


@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("ndims", [1, 2], ids="ndims={}".format)
def test_histogram_approx_discrete(use_dask, ndims):
    data = np.random.randint(0, 100, size=(10000, *(2,) * (ndims - 1)))
    if use_dask:
        dask = pytest.importorskip("dask")
        dask_df = pytest.importorskip("dask.dataframe")
        data = dask_df.from_array(data)
    with pm.Model():
        s = pm.Exponential("s", 1.0, size=2 if ndims > 1 else 1)
        pot = pmx.distributions.histogram_approximation(
            "histogram_potential", pm.Poisson.dist(s), observed=data, min_count=10
        )
        trace = pm.sample(10, tune=0)  # very fast

pyproject.toml (+1)

@@ -1,6 +1,7 @@
 [tool.pytest.ini_options]
 minversion = "6.0"
 xfail_strict=true
+addopts = "--doctest-modules pymc_experimental"

 [tool.black]
 line-length = 100

requirements-dev.txt (+1)

@@ -0,0 +1 @@
+dask[all]

requirements.txt (+1)

@@ -1 +1,2 @@
 pymc>=4.0.0b6
+xhistogram

setup.py (+1)

@@ -82,4 +82,5 @@ def get_version():
     classifiers=classifiers,
     python_requires=">=3.8",
     install_requires=install_reqs,
+    extras_require=dict(dask=["dask[all]"]),
 )
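
With this extra declared, the optional dask dependencies can be pulled in alongside the package, e.g. `pip install pymc-experimental[dask]` (assuming the published distribution name is `pymc-experimental`).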
