
Commit 5924880

Fixed pymc-devs#3310. Added broadcast_distribution_samples, which helps broadcast the results of multiple rvs calls with different size and distribution parameter shapes. Added shape guards to other continuous distributions.
1 parent 081e7f4 commit 5924880

File tree: 5 files changed, +102 / -13 lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ This will be the last release to support Python 2.
 - Fixed `Rice` distribution, which inconsistently mixed two parametrizations (#3286).
 - `Rice` distribution now accepts multiple parameters and observations and is usable with NUTS (#3289).
 - `sample_posterior_predictive` no longer calls `draw_values` to initialize the shape of the ppc trace. This called could lead to `ValueError`'s when sampling the ppc from a model with `Flat` or `HalfFlat` prior distributions (Fix issue #3294).
+- Added the `broadcast_distribution_samples` function that helps broadcasting arrays of drawn samples, taking into account the requested `size` and the inferred distribution shape. This sometimes is needed by distributions that call several `rvs` separately within their `random` method, such as the `ZeroInflatedPoisson` (Fix issue #3310).
+- The `Wald`, `Kumaraswamy`, `LogNormal`, `Pareto`, `Cauchy`, `HalfCauchy`, `Weibull` and `ExGaussian` distributions `random` method used a hidden `_random` function that was written with scalars in mind. This could potentially lead to artificial correlations between random draws. Added shape guards and broadcasting of the distribution samples to prevent this (Similar to issue #3310).


 ### Deprecations
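To illustrate the "artificial correlations" the second release-note entry refers to, here is a small numpy-only sketch (not part of the commit; the parameter vector and sample count are made up): when a scalar-minded `_random` draws a single batch of base variates sized only by `size` and then broadcasts it against a vector of parameters, every parameter reuses the same underlying draws, so the resulting columns are perfectly correlated. Drawing with the parameter shape appended, as the new shape guards do, gives independent draws per parameter.

import numpy as np

sd = np.array([0.5, 1.0, 2.0])   # hypothetical vector of scale parameters
size = 1000

# Scalar-minded draw: one (1000,) batch of standard normals reused for every sd
z_bad = np.random.normal(size=size)
bad = sd * z_bad[:, None]

# Shape-guarded draw: an independent (1000, 3) batch, one column per sd
z_good = np.random.normal(size=(size,) + sd.shape)
good = sd * z_good

print(np.corrcoef(bad, rowvar=False))   # off-diagonal entries are exactly 1
print(np.corrcoef(good, rowvar=False))  # off-diagonal entries are near 0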

pymc3/distributions/continuous.py

Lines changed: 54 additions & 9 deletions
@@ -21,7 +21,8 @@
     alltrue_elemwise, betaln, bound, gammaln, i0e, incomplete_beta, logpow,
     normal_lccdf, normal_lcdf, SplineWrapper, std_cdf, zvalue,
 )
-from .distribution import Continuous, draw_values, generate_samples
+from .distribution import (Continuous, draw_values, generate_samples,
+                           to_tuple, broadcast_distribution_samples)
 
 __all__ = ['Uniform', 'Flat', 'HalfFlat', 'Normal', 'TruncatedNormal', 'Beta',
            'Kumaraswamy', 'Exponential', 'Laplace', 'StudentT', 'Cauchy',
@@ -937,10 +938,15 @@ def get_mu_lam_phi(self, mu, lam, phi):
                          'mu and lam, mu and phi, or lam and phi.')
 
     def _random(self, mu, lam, alpha, size=None):
-        v = np.random.normal(size=size)**2
+        _size = alpha.shape
+        if size is not None:
+            size = to_tuple(size)
+            if _size[:len(size)] != size:
+                _size = size + _size
+        v = np.random.normal(size=_size)**2
         value = (mu + (mu**2) * v / (2. * lam) - mu / (2. * lam)
                  * np.sqrt(4. * mu * lam * v + (mu * v)**2))
-        z = np.random.uniform(size=size)
+        z = np.random.uniform(size=_size)
         i = np.floor(z - mu / (mu + value)) * 2 + 1
         value = (value**-i) * (mu**(i + 1))
         return value + alpha
@@ -964,6 +970,8 @@ def random(self, point=None, size=None):
         """
         mu, lam, alpha = draw_values([self.mu, self.lam, self.alpha],
                                      point=point, size=size)
+        mu, lam, alpha = broadcast_distribution_samples([mu, lam, alpha],
+                                                        size=size)
         return generate_samples(self._random,
                                 mu, lam, alpha,
                                 dist_shape=self.shape,
@@ -1270,7 +1278,12 @@ def __init__(self, a, b, *args, **kwargs):
         assert_negative_support(b, 'b', 'Kumaraswamy')
 
     def _random(self, a, b, size=None):
-        u = np.random.uniform(size=size)
+        _size = a.shape
+        if size is not None:
+            size = to_tuple(size)
+            if _size[:len(size)] != size:
+                _size = size + _size
+        u = np.random.uniform(size=_size)
         return (1 - (1 - u) ** (1 / b)) ** (1 / a)
 
     def random(self, point=None, size=None):
@@ -1292,6 +1305,7 @@ def random(self, point=None, size=None):
         """
         a, b = draw_values([self.a, self.b],
                            point=point, size=size)
+        a, b = broadcast_distribution_samples([a, b], size=size)
         return generate_samples(self._random, a, b,
                                 dist_shape=self.shape,
                                 size=size)
@@ -1644,7 +1658,12 @@ def __init__(self, mu=0, sd=None, tau=None, *args, **kwargs):
         assert_negative_support(sd, 'sd', 'Lognormal')
 
     def _random(self, mu, tau, size=None):
-        samples = np.random.normal(size=size)
+        _size = tau.shape
+        if size is not None:
+            size = to_tuple(size)
+            if _size[:len(size)] != size:
+                _size = size + _size
+        samples = np.random.normal(size=_size)
         return np.exp(mu + (tau**-0.5) * samples)
 
     def random(self, point=None, size=None):
@@ -1665,6 +1684,7 @@ def random(self, point=None, size=None):
         array
         """
         mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
+        mu, tau = broadcast_distribution_samples([mu, tau], size=size)
        return generate_samples(self._random, mu, tau,
                                dist_shape=self.shape,
                                size=size)
@@ -1930,7 +1950,12 @@ def __init__(self, alpha, m, transform='lowerbound', *args, **kwargs):
         super(Pareto, self).__init__(transform=transform, *args, **kwargs)
 
     def _random(self, alpha, m, size=None):
-        u = np.random.uniform(size=size)
+        _size = alpha.shape
+        if size is not None:
+            size = to_tuple(size)
+            if _size[:len(size)] != size:
+                _size = size + _size
+        u = np.random.uniform(size=_size)
         return m * (1. - u)**(-1. / alpha)
 
     def random(self, point=None, size=None):
@@ -1952,6 +1977,7 @@ def random(self, point=None, size=None):
         """
         alpha, m = draw_values([self.alpha, self.m],
                                point=point, size=size)
+        alpha, m = broadcast_distribution_samples([alpha, m], size=size)
         return generate_samples(self._random, alpha, m,
                                 dist_shape=self.shape,
                                 size=size)
@@ -2054,7 +2080,12 @@ def __init__(self, alpha, beta, *args, **kwargs):
         assert_negative_support(beta, 'beta', 'Cauchy')
 
     def _random(self, alpha, beta, size=None):
-        u = np.random.uniform(size=size)
+        _size = alpha.shape
+        if size is not None:
+            size = to_tuple(size)
+            if _size[:len(size)] != size:
+                _size = size + _size
+        u = np.random.uniform(size=_size)
         return alpha + beta * np.tan(np.pi * (u - 0.5))
 
     def random(self, point=None, size=None):
@@ -2076,6 +2107,7 @@ def random(self, point=None, size=None):
         """
         alpha, beta = draw_values([self.alpha, self.beta],
                                   point=point, size=size)
+        alpha, beta = broadcast_distribution_samples([alpha, beta], size=size)
         return generate_samples(self._random, alpha, beta,
                                 dist_shape=self.shape,
                                 size=size)
@@ -2163,7 +2195,12 @@ def __init__(self, beta, *args, **kwargs):
         assert_negative_support(beta, 'beta', 'HalfCauchy')
 
     def _random(self, beta, size=None):
-        u = np.random.uniform(size=size)
+        _size = beta.shape
+        if size is not None:
+            size = to_tuple(size)
+            if _size[:len(size)] != size:
+                _size = size + _size
+        u = np.random.uniform(size=_size)
         return beta * np.abs(np.tan(np.pi * (u - 0.5)))
 
     def random(self, point=None, size=None):
@@ -2637,9 +2674,15 @@ def random(self, point=None, size=None):
         """
         alpha, beta = draw_values([self.alpha, self.beta],
                                   point=point, size=size)
+        alpha, beta = broadcast_distribution_samples([alpha, beta], size=size)
 
         def _random(a, b, size=None):
-            return b * (-np.log(np.random.uniform(size=size)))**(1 / a)
+            _size = a.shape
+            if size is not None:
+                size = to_tuple(size)
+                if _size[:len(size)] != size:
+                    _size = size + _size
+            return b * (-np.log(np.random.uniform(size=_size)))**(1 / a)
 
         return generate_samples(_random, alpha, beta,
                                 dist_shape=self.shape,
@@ -2921,6 +2964,8 @@ def random(self, point=None, size=None):
         """
         mu, sigma, nu = draw_values([self.mu, self.sigma, self.nu],
                                     point=point, size=size)
+        mu, sigma, nu = broadcast_distribution_samples([mu, sigma, nu],
+                                                       size=size)
 
         def _random(mu, sigma, nu, size=None):
             return (np.random.normal(mu, sigma, size=size)
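Every `_random` touched above repeats the same shape guard: start from the (already broadcast) parameter shape and prepend the requested `size` only if it is not already a prefix of that shape. A minimal stand-alone sketch of that logic follows; the local `to_tuple` is only a stand-in for the helper imported from `.distribution`, and the parameter shapes are illustrative.

import numpy as np

def to_tuple(size):
    # Local stand-in for pymc3.distributions.distribution.to_tuple
    if size is None:
        return tuple()
    try:
        return tuple(size)
    except TypeError:
        return (size,)

def guarded_size(param_shape, size=None):
    # Mirrors the `_size` guard used in the _random methods above:
    # prepend `size` to the parameter shape unless it is already its prefix.
    _size = param_shape
    if size is not None:
        size = to_tuple(size)
        if _size[:len(size)] != size:
            _size = size + _size
    return _size

beta = np.ones(20)                         # parameters already broadcast to (20,)
print(guarded_size(beta.shape))            # (20,)
print(guarded_size(beta.shape, size=100))  # (100, 20)
print(guarded_size((100, 20), size=100))   # (100, 20): size already prepended, not duplicated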

pymc3/distributions/discrete.py

Lines changed: 9 additions & 4 deletions
@@ -6,7 +6,8 @@
 
 from pymc3.util import get_variable_name
 from .dist_math import bound, factln, binomln, betaln, logpow, random_choice
-from .distribution import Discrete, draw_values, generate_samples
+from .distribution import (Discrete, draw_values, generate_samples,
+                           broadcast_distribution_samples)
 from pymc3.math import tround, sigmoid, logaddexp, logit, log1pexp
 
 
@@ -345,6 +346,7 @@ def _ppf(self, p):
 
     def _random(self, q, beta, size=None):
         p = np.random.uniform(size=size)
+        p, q, beta = broadcast_distribution_samples([p, q, beta], size=size)
 
         return np.ceil(np.power(np.log(1 - p) / np.log(q), 1. / beta)) - 1
 
@@ -847,7 +849,8 @@ def random(self, point=None, size=None):
         g = generate_samples(stats.poisson.rvs, theta,
                              dist_shape=self.shape,
                              size=size)
-        return g * (np.random.random(np.squeeze(g.shape)) < psi)
+        g, psi = broadcast_distribution_samples([g, psi], size=size)
+        return g * (np.random.random(g.shape) < psi)
 
     def logp(self, value):
         psi = self.psi
@@ -939,7 +942,8 @@ def random(self, point=None, size=None):
         g = generate_samples(stats.binom.rvs, n, p,
                              dist_shape=self.shape,
                              size=size)
-        return g * (np.random.random(np.squeeze(g.shape)) < psi)
+        g, psi = broadcast_distribution_samples([g, psi], size=size)
+        return g * (np.random.random(g.shape) < psi)
 
     def logp(self, value):
         psi = self.psi
@@ -1057,7 +1061,8 @@ def random(self, point=None, size=None):
                              dist_shape=self.shape,
                              size=size)
         g[g == 0] = np.finfo(float).eps  # Just in case
-        return stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
+        g, psi = broadcast_distribution_samples([g, psi], size=size)
+        return stats.poisson.rvs(g) * (np.random.random(g.shape) < psi)
 
     def logp(self, value):
         alpha = self.alpha
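The zero-inflated fixes above all share one failure mode: `g` comes back from `generate_samples` with `size` prepended to the distribution shape (for example `(5000, 20)` in issue #3310), while `psi` is drawn with shape `(5000,)`, so the old mask `np.random.random(np.squeeze(g.shape)) < psi` could not broadcast. Below is a rough numpy-only sketch of the problem and of why aligning the two sample arrays first fixes it; the shapes mirror the test added in this commit, the Poisson mean is made up, and the hand-written `psi[:, None]` stands in for what `broadcast_distribution_samples` does automatically.

import numpy as np

size, dist_shape = 5000, 20
g = np.random.poisson(2.0, size=(size, dist_shape))  # one Poisson draw per (sample, dim)
psi = np.random.uniform(size=size)                   # one psi draw per sample

# Old behaviour: a (5000, 20) mask compared against a (5000,) psi fails to broadcast
try:
    g * (np.random.random(np.squeeze(g.shape)) < psi)
except ValueError as err:
    print('old code:', err)

# New behaviour: align the sample arrays first, then build the mask elementwise
g_b, psi_b = np.broadcast_arrays(g, psi[:, None])    # both become (5000, 20)
zip_samples = g_b * (np.random.random(g_b.shape) < psi_b)
print(zip_samples.shape)                             # (5000, 20)

In the commit itself this alignment is done by broadcast_distribution_samples, which inserts the missing trailing axes based on the requested size instead of requiring the caller to add them by hand.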

pymc3/distributions/distribution.py

Lines changed: 27 additions & 0 deletions
@@ -640,3 +640,30 @@ def generate_samples(generator, *args, **kwargs):
     if one_d and samples.shape[-1] == 1:
         samples = samples.reshape(samples.shape[:-1])
     return np.asarray(samples)
+
+
+def broadcast_distribution_samples(samples, size=None):
+    if size is None:
+        return np.broadcast_arrays(*samples)
+    _size = to_tuple(size)
+    try:
+        broadcasted_samples = np.broadcast_arrays(*samples)
+    except ValueError:
+        # Raw samples shapes
+        p_shapes = [p.shape for p in samples]
+        # samples shapes without the size prepend
+        sp_shapes = [s[len(_size):] if _size == s[:len(_size)] else s
+                     for s in p_shapes]
+        broadcast_shape = np.broadcast(*[np.empty(s) for s in sp_shapes]).shape
+        broadcasted_samples = []
+        for param, p_shape, sp_shape in zip(samples, p_shapes, sp_shapes):
+            if _size == p_shape[:len(_size)]:
+                slicer_head = [slice(None)] * len(_size)
+            else:
+                slicer_head = [np.newaxis] * len(_size)
+            slicer_tail = ([np.newaxis] * (len(broadcast_shape) -
+                                           len(sp_shape)) +
+                           [slice(None)] * len(sp_shape))
+            broadcasted_samples.append(param[tuple(slicer_head + slicer_tail)])
+        broadcasted_samples = np.broadcast_arrays(*broadcasted_samples)
+    return broadcasted_samples
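A small usage sketch of the new helper (assuming a checkout containing this commit, with illustrative shapes): given sample arrays that differ only in whether the requested size is already prepended to their shape, it strips the size prefix, broadcasts the remaining distribution shapes, and re-aligns every array to size plus the common broadcast shape.

import numpy as np
from pymc3.distributions.distribution import broadcast_distribution_samples

size = (100,)
mu = np.random.normal(size=(100, 5))  # size prefix plus a (5,) distribution shape
tau = np.ones(100)                    # only the size prefix (a scalar parameter drawn 100 times)

# np.broadcast_arrays alone would fail here: (100, 5) vs (100,).
mu_b, tau_b = broadcast_distribution_samples([mu, tau], size=size)
print(mu_b.shape, tau_b.shape)  # (100, 5) (100, 5)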

pymc3/tests/test_sampling.py

Lines changed: 10 additions & 0 deletions
@@ -467,3 +467,13 @@ def test_shape_edgecase(self):
             x = pm.Normal('x', mu=mu, sd=sd, shape=5)
             prior = pm.sample_prior_predictive(10)
         assert prior['mu'].shape == (10, 5)
+
+    def test_zeroinflatedpoisson(self):
+        with pm.Model():
+            theta = pm.Beta('theta', alpha=1, beta=1)
+            psi = pm.HalfNormal('psi', sd=1)
+            pm.ZeroInflatedPoisson('suppliers', psi=psi, theta=theta, shape=20)
+            gen_data = pm.sample_prior_predictive(samples=5000)
+            assert gen_data['theta'].shape == (5000,)
+            assert gen_data['psi'].shape == (5000,)
+            assert gen_data['suppliers'].shape == (5000, 20)
