Skip to content

Commit 353e5a4

Browse files
committed
add dask function
1 parent 7cd5735 commit 353e5a4

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

pymc_experimental/tests/test_histogram_trick.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88
def test_histogram_init(use_dask):
99
data = np.random.randn(10000)
1010
if use_dask:
11+
dask = pytest.importorskip("dask")
1112
dask_df = pytest.importorskip("dask.dataframe")
1213
data = dask_df.from_array(data)
1314
histogram = pmx.utils.quantile_histogram(data, n_quantiles=100)
15+
if use_dask:
16+
(histogram,) = dask.compute(histogram)
1417
assert isinstance(histogram, dict)
18+
assert isinstance(histogram["mid"], np.ndarray)
1519
assert histogram["mid"].shape == (99,)
1620
assert histogram["low"].shape == (99,)
1721
assert histogram["upper"].shape == (99,)

pymc_experimental/utils/histogram_trick.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def _(data: ArrayLike, n_quantiles=1000) -> Dict[str, ArrayLike]:
3737

3838
@quantile_histogram.register(dask.dataframe.Series)
3939
def _(data: dask.dataframe.Series, n_quantiles=1000) -> Dict[str, ArrayLike]:
40-
quantiles = data.quantile(np.linspace(0, 1, n_quantiles)).persist()
41-
count, _ = dask.array.histogram(data, quantiles.compute())
42-
low = quantiles.to_dask_array(lengths=True)[:-1]
43-
upper = quantiles.to_dask_array(lengths=True)[1:]
40+
quantiles = data.quantile(np.linspace(0, 1, n_quantiles)).to_dask_array(lengths=True)
41+
count, _ = dask.array.histogram(data, quantiles)
42+
low = quantiles[:-1]
43+
upper = quantiles[1:]
4444
result = dict(
4545
low=low,
4646
upper=upper,

0 commit comments

Comments
 (0)