Skip to content

Commit 29ee733

Browse files
authored
add histogram approximation to api reference (#49)
1 parent 3e78c36 commit 29ee733

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

docs/api_reference.rst

+7
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,10 @@ methods in the current release of PyMC experimental.
1515
.. automodule:: pymc_experimental.bart
1616
:members: BART, PGBART, plot_dependence, plot_variable_importance, predict
1717

18+
19+
:mod:`pymc_experimental.distributions`
20+
======================================
21+
22+
.. automodule:: pymc_experimental.distributions.histogram_utils
23+
:members: histogram_approximation
24+

pymc_experimental/distributions/histogram_utils.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
6666
return dict(mid=mid, count=count)
6767

6868

69-
def histogram_approximation(name, dist, *, observed: ArrayLike, **h_kwargs):
69+
def histogram_approximation(name, dist, *, observed, **h_kwargs):
7070
"""Approximate a distribution with a histogram potential.
7171
7272
Parameters
@@ -76,7 +76,8 @@ def histogram_approximation(name, dist, *, observed: ArrayLike, **h_kwargs):
7676
dist : aesara.tensor.var.TensorVariable
7777
The output of pm.Distribution.dist()
7878
observed : ArrayLike
79-
Observed value to construct a histogram. Histogram is computed over 0th axis. Dask is supported.
79+
Observed value to construct a histogram. Histogram is computed over 0th axis.
80+
Dask is supported.
8081
8182
Returns
8283
-------
@@ -86,34 +87,37 @@ def histogram_approximation(name, dist, *, observed: ArrayLike, **h_kwargs):
8687
Examples
8788
--------
8889
Discrete variables are reduced to unique repetitions (up to min_count)
90+
8991
>>> import pymc as pm
9092
>>> import pymc_experimental as pmx
9193
>>> production = np.random.poisson([1, 2, 5], size=(1000, 3))
9294
>>> with pm.Model(coords=dict(plant=range(3))):
9395
... lam = pm.Exponential("lam", 1.0, dims="plant")
9496
... pot = pmx.distributions.histogram_approximation(
95-
... "histogram_potential", pm.Poisson.dist(lam), observed=production, min_count=2
97+
... "pot", pm.Poisson.dist(lam), observed=production, min_count=2
9698
... )
9799
98100
Continuous variables are discretized into n_quantiles
101+
99102
>>> measurements = np.random.normal([1, 2, 3], [0.1, 0.4, 0.2], size=(10000, 3))
100103
>>> with pm.Model(coords=dict(tests=range(3))):
101104
... m = pm.Normal("m", dims="tests")
102105
... s = pm.LogNormal("s", dims="tests")
103106
... pot = pmx.distributions.histogram_approximation(
104-
... "histogram_potential", pm.Normal.dist(m, s),
107+
... "pot", pm.Normal.dist(m, s),
105108
... observed=measurements, n_quantiles=50
106109
... )
107110
108111
For special cases like Zero Inflation in Continuous variables there is a flag.
109112
The ``zero_inflation=True`` flag adds a separate bin for zeros.
113+
110114
>>> measurements = abs(measurements)
111115
>>> measurements[100:] = 0
112116
>>> with pm.Model(coords=dict(tests=range(3))):
113117
... m = pm.Normal("m", dims="tests")
114118
... s = pm.LogNormal("s", dims="tests")
115119
... pot = pmx.distributions.histogram_approximation(
116-
... "histogram_potential", pm.Normal.dist(m, s),
120+
... "pot", pm.Normal.dist(m, s),
117121
... observed=measurements, n_quantiles=50, zero_inflation=True
118122
... )
119123
"""

0 commit comments

Comments
 (0)