Skip to content

Commit 452be39

Browse files
ColCarrollJunpeng Lao
authored and
Junpeng Lao
committed
Add size everywhere (#2984)
* Add size arg everywhere * remove print statement * No size in transforms * Remove test
1 parent 77fc24b commit 452be39

File tree

6 files changed

+43
-59
lines changed

6 files changed

+43
-59
lines changed

pymc3/distributions/bound.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,17 @@ def random(self, point=None, size=None):
6868
if self.lower is None and self.upper is None:
6969
return self._wrapped.random(point=point, size=size)
7070
elif self.lower is not None and self.upper is not None:
71-
lower, upper = draw_values([self.lower, self.upper], point=point)
71+
lower, upper = draw_values([self.lower, self.upper], point=point, size=size)
7272
return generate_samples(self._random, lower, upper, point,
7373
dist_shape=self.shape,
7474
size=size)
7575
elif self.lower is not None:
76-
lower = draw_values([self.lower], point=point)
76+
lower = draw_values([self.lower], point=point, size=size)
7777
return generate_samples(self._random, lower, np.inf, point,
7878
dist_shape=self.shape,
7979
size=size)
8080
else:
81-
upper = draw_values([self.upper], point=point)
81+
upper = draw_values([self.upper], point=point, size=size)
8282
return generate_samples(self._random, -np.inf, upper, point,
8383
dist_shape=self.shape,
8484
size=size)

pymc3/distributions/continuous.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def __init__(self, mu=0, sd=None, tau=None, **kwargs):
305305

306306
def random(self, point=None, size=None):
307307
mu, tau, _ = draw_values([self.mu, self.tau, self.sd],
308-
point=point)
308+
point=point, size=size)
309309
return generate_samples(stats.norm.rvs, loc=mu, scale=tau**-0.5,
310310
dist_shape=self.shape,
311311
size=size)
@@ -406,7 +406,7 @@ def __init__(self, sd=None, tau=None, *args, **kwargs):
406406
assert_negative_support(sd, 'sd', 'HalfNormal')
407407

408408
def random(self, point=None, size=None):
409-
sd = draw_values([self.sd], point=point)[0]
409+
sd = draw_values([self.sd], point=point, size=size)[0]
410410
return generate_samples(stats.halfnorm.rvs, loc=0., scale=sd,
411411
dist_shape=self.shape,
412412
size=size)
@@ -547,7 +547,7 @@ def _random(self, mu, lam, alpha, size=None):
547547

548548
def random(self, point=None, size=None):
549549
mu, lam, alpha = draw_values([self.mu, self.lam, self.alpha],
550-
point=point)
550+
point=point, size=size)
551551
return generate_samples(self._random,
552552
mu, lam, alpha,
553553
dist_shape=self.shape,
@@ -672,7 +672,7 @@ def get_alpha_beta(self, alpha=None, beta=None, mu=None, sd=None):
672672

673673
def random(self, point=None, size=None):
674674
alpha, beta = draw_values([self.alpha, self.beta],
675-
point=point)
675+
point=point, size=size)
676676
return generate_samples(stats.beta.rvs, alpha, beta,
677677
dist_shape=self.shape,
678678
size=size)
@@ -2239,7 +2239,7 @@ def __init__(self, nu=None, sd=None, *args, **kwargs):
22392239

22402240
def random(self, point=None, size=None, repeat=None):
22412241
nu, sd = draw_values([self.nu, self.sd],
2242-
point=point)
2242+
point=point, size=size)
22432243
return generate_samples(stats.rice.rvs, b=nu, scale=sd, loc=0,
22442244
dist_shape=self.shape, size=size)
22452245

pymc3/distributions/discrete.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _random(self, alpha, beta, n, size=None):
166166

167167
def random(self, point=None, size=None):
168168
alpha, beta, n = \
169-
draw_values([self.alpha, self.beta, self.n], point=point)
169+
draw_values([self.alpha, self.beta, self.n], point=point, size=size)
170170
return generate_samples(self._random, alpha=alpha, beta=beta, n=n,
171171
dist_shape=self.shape,
172172
size=size)
@@ -247,7 +247,7 @@ def __init__(self, p=None, logit_p=None, *args, **kwargs):
247247
self.mode = tt.cast(tround(self.p), 'int8')
248248

249249
def random(self, point=None, size=None):
250-
p = draw_values([self.p], point=point)[0]
250+
p = draw_values([self.p], point=point, size=size)[0]
251251
return generate_samples(stats.bernoulli.rvs, p,
252252
dist_shape=self.shape,
253253
size=size)
@@ -343,7 +343,7 @@ def _random(self, q, beta, size=None):
343343
return np.ceil(np.power(np.log(1 - p) / np.log(q), 1. / beta)) - 1
344344

345345
def random(self, point=None, size=None):
346-
q, beta = draw_values([self.q, self.beta], point=point)
346+
q, beta = draw_values([self.q, self.beta], point=point, size=size)
347347

348348
return generate_samples(self._random, q, beta,
349349
dist_shape=self.shape,
@@ -410,7 +410,7 @@ def __init__(self, mu, *args, **kwargs):
410410
self.mode = tt.floor(mu).astype('int32')
411411

412412
def random(self, point=None, size=None):
413-
mu = draw_values([self.mu], point=point)[0]
413+
mu = draw_values([self.mu], point=point, size=size)[0]
414414
return generate_samples(stats.poisson.rvs, mu,
415415
dist_shape=self.shape,
416416
size=size)
@@ -490,7 +490,7 @@ def __init__(self, mu, alpha, *args, **kwargs):
490490
self.mode = tt.floor(mu).astype('int32')
491491

492492
def random(self, point=None, size=None):
493-
mu, alpha = draw_values([self.mu, self.alpha], point=point)
493+
mu, alpha = draw_values([self.mu, self.alpha], point=point, size=size)
494494
g = generate_samples(stats.gamma.rvs, alpha, scale=mu / alpha,
495495
dist_shape=self.shape,
496496
size=size)
@@ -564,7 +564,7 @@ def __init__(self, p, *args, **kwargs):
564564
self.mode = 1
565565

566566
def random(self, point=None, size=None):
567-
p = draw_values([self.p], point=point)[0]
567+
p = draw_values([self.p], point=point, size=size)[0]
568568
return generate_samples(np.random.geometric, p,
569569
dist_shape=self.shape,
570570
size=size)
@@ -636,7 +636,7 @@ def _random(self, lower, upper, size=None):
636636
return samples
637637

638638
def random(self, point=None, size=None):
639-
lower, upper = draw_values([self.lower, self.upper], point=point)
639+
lower, upper = draw_values([self.lower, self.upper], point=point, size=size)
640640
return generate_samples(self._random,
641641
lower, upper,
642642
dist_shape=self.shape,
@@ -714,7 +714,7 @@ def random_choice(k, *args, **kwargs):
714714
else:
715715
return np.random.choice(k, *args, **kwargs)
716716

717-
p, k = draw_values([self.p, self.k], point=point)
717+
p, k = draw_values([self.p, self.k], point=point, size=size)
718718
return generate_samples(partial(random_choice, np.arange(k)),
719719
p=p,
720720
broadcast_shape=p.shape[:-1] or (1,),
@@ -764,7 +764,7 @@ def __init__(self, c, *args, **kwargs):
764764
self.mean = self.median = self.mode = self.c = c = tt.as_tensor_variable(c)
765765

766766
def random(self, point=None, size=None):
767-
c = draw_values([self.c], point=point)[0]
767+
c = draw_values([self.c], point=point, size=size)[0]
768768
dtype = np.array(c).dtype
769769

770770
def _random(c, dtype=dtype, size=None):
@@ -845,7 +845,7 @@ def __init__(self, psi, theta, *args, **kwargs):
845845
self.mode = self.pois.mode
846846

847847
def random(self, point=None, size=None):
848-
theta, psi = draw_values([self.theta, self.psi], point=point)
848+
theta, psi = draw_values([self.theta, self.psi], point=point, size=size)
849849
g = generate_samples(stats.poisson.rvs, theta,
850850
dist_shape=self.shape,
851851
size=size)
@@ -938,7 +938,7 @@ def __init__(self, psi, n, p, *args, **kwargs):
938938
self.mode = self.bin.mode
939939

940940
def random(self, point=None, size=None):
941-
n, p, psi = draw_values([self.n, self.p, self.psi], point=point)
941+
n, p, psi = draw_values([self.n, self.p, self.psi], point=point, size=size)
942942
g = generate_samples(stats.binom.rvs, n, p,
943943
dist_shape=self.shape,
944944
size=size)
@@ -1056,7 +1056,7 @@ def __init__(self, psi, mu, alpha, *args, **kwargs):
10561056

10571057
def random(self, point=None, size=None):
10581058
mu, alpha, psi = draw_values(
1059-
[self.mu, self.alpha, self.psi], point=point)
1059+
[self.mu, self.alpha, self.psi], point=point, size=size)
10601060
g = generate_samples(stats.gamma.rvs, alpha, scale=mu / alpha,
10611061
dist_shape=self.shape,
10621062
size=size)

pymc3/distributions/distribution.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ def __init__(self, shape=(), dtype=None, defaults=('median', 'mean', 'mode'),
177177

178178
class DensityDist(Distribution):
179179
"""Distribution based on a given log density function.
180-
181-
A distribution with the passed log density function is created.
180+
181+
A distribution with the passed log density function is created.
182182
Requires a custom random function passed as kwarg `random` to
183183
enable sampling.
184184
@@ -200,7 +200,7 @@ def __init__(self, logp, shape=(), dtype=None, testval=0, random=None, *args, **
200200
shape, dtype, testval, *args, **kwargs)
201201
self.logp = logp
202202
self.rand = random
203-
203+
204204
def random(self, *args, **kwargs):
205205
if self.rand is not None:
206206
return self.rand(*args, **kwargs)
@@ -345,9 +345,7 @@ def _draw_value(param, point=None, givens=None, size=None):
345345
size : int, optional
346346
Number of samples
347347
"""
348-
if isinstance(param, numbers.Number):
349-
return param
350-
elif isinstance(param, np.ndarray):
348+
if isinstance(param, (numbers.Number, np.ndarray)):
351349
return param
352350
elif isinstance(param, tt.TensorConstant):
353351
return param.value
@@ -357,7 +355,7 @@ def _draw_value(param, point=None, givens=None, size=None):
357355
if point and hasattr(param, 'model') and param.name in point:
358356
return point[param.name]
359357
elif hasattr(param, 'random') and param.random is not None:
360-
return param.random(point=point, size=None)
358+
return param.random(point=point, size=size)
361359
else:
362360
if givens:
363361
variables, values = list(zip(*givens))

pymc3/distributions/multivariate.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def random(self, point=None, size=None):
234234
size = [size]
235235

236236
if self._cov_type == 'cov':
237-
mu, cov = draw_values([self.mu, self.cov], point=point)
237+
mu, cov = draw_values([self.mu, self.cov], point=point, size=size)
238238
if mu.shape[-1] != cov.shape[-1]:
239239
raise ValueError("Shapes for mu and cov don't match")
240240

@@ -246,15 +246,15 @@ def random(self, point=None, size=None):
246246
return np.nan * np.zeros(size)
247247
return dist.rvs(size)
248248
elif self._cov_type == 'chol':
249-
mu, chol = draw_values([self.mu, self.chol_cov], point=point)
249+
mu, chol = draw_values([self.mu, self.chol_cov], point=point, size=size)
250250
if mu.shape[-1] != chol[0].shape[-1]:
251251
raise ValueError("Shapes for mu and chol don't match")
252252

253253
size.append(mu.shape[-1])
254254
standard_normal = np.random.standard_normal(size)
255255
return mu + np.dot(standard_normal, chol.T)
256256
else:
257-
mu, tau = draw_values([self.mu, self.tau], point=point)
257+
mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
258258
if mu.shape[-1] != tau[0].shape[-1]:
259259
raise ValueError("Shapes for mu and tau don't match")
260260

@@ -338,15 +338,15 @@ def __init__(self, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None,
338338
self.mean = self.median = self.mode = self.mu = self.mu
339339

340340
def random(self, point=None, size=None):
341-
nu, mu = draw_values([self.nu, self.mu], point=point)
341+
nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
342342
if self._cov_type == 'cov':
343-
cov, = draw_values([self.cov], point=point)
343+
cov, = draw_values([self.cov], point=point, size=size)
344344
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov)
345345
elif self._cov_type == 'tau':
346-
tau, = draw_values([self.tau], point=point)
346+
tau, = draw_values([self.tau], point=point, size=size)
347347
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
348348
else:
349-
chol, = draw_values([self.chol_cov], point=point)
349+
chol, = draw_values([self.chol_cov], point=point, size=size)
350350
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)
351351

352352
samples = dist.random(point, size)
@@ -422,7 +422,7 @@ def __init__(self, a, transform=transforms.stick_breaking,
422422
np.nan)
423423

424424
def random(self, point=None, size=None):
425-
a = draw_values([self.a], point=point)[0]
425+
a = draw_values([self.a], point=point, size=size)[0]
426426

427427
def _random(a, size=None):
428428
return stats.dirichlet.rvs(a, None if size == a.shape else size)
@@ -545,7 +545,7 @@ def _random(self, n, p, size=None):
545545
return randnum.astype(original_dtype)
546546

547547
def random(self, point=None, size=None):
548-
n, p = draw_values([self.n, self.p], point=point)
548+
n, p = draw_values([self.n, self.p], point=point, size=size)
549549
samples = generate_samples(self._random, n, p,
550550
dist_shape=self.shape,
551551
size=size)
@@ -679,7 +679,7 @@ def __init__(self, nu, V, *args, **kwargs):
679679
np.nan)
680680

681681
def random(self, point=None, size=None):
682-
nu, V = draw_values([self.nu, self.V], point=point)
682+
nu, V = draw_values([self.nu, self.V], point=point, size=size)
683683
size= 1 if size is None else size
684684
return generate_samples(stats.wishart.rvs, np.asscalar(nu), V,
685685
broadcast_shape=(size,))
@@ -1071,7 +1071,7 @@ def _random(self, n, eta, size=None):
10711071
return C.transpose((1, 2, 0))[np.triu_indices(n, k=1)].T
10721072

10731073
def random(self, point=None, size=None):
1074-
n, eta = draw_values([self.n, self.eta], point=point)
1074+
n, eta = draw_values([self.n, self.eta], point=point, size=size)
10751075
size= 1 if size is None else size
10761076
samples = generate_samples(self._random, n, eta,
10771077
broadcast_shape=(size,))
@@ -1264,8 +1264,8 @@ def random(self, point=None, size=None):
12641264

12651265
mu, colchol, rowchol = draw_values(
12661266
[self.mu, self.colchol_cov, self.rowchol_cov],
1267-
point=point
1268-
)
1267+
point=point,
1268+
size=size)
12691269
standard_normal = np.random.standard_normal(size)
12701270
return mu + np.matmul(rowchol, np.matmul(standard_normal, colchol.T))
12711271

pymc3/tests/test_distributions_random.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -97,20 +97,6 @@ def test_random_sample_returns_nd_array(self):
9797
assert isinstance(mu, np.ndarray)
9898
assert isinstance(tau, np.ndarray)
9999

100-
def test_random_sample_returns_correctly(self):
101-
# Based on what we discovered in #GH2909
102-
with pm.Model():
103-
a = pm.Uniform('a', lower=0, upper=1, shape=10)
104-
b = pm.Binomial('b', n=1, p=a, shape=10)
105-
array_of_uniform = a.random(size=10000).mean(axis=0)
106-
array_of_binomial = b.random(size=10000).mean(axis=0)
107-
npt.assert_allclose(array_of_uniform, [0.49886929, 0.49949713, 0.49946077, 0.49922606, 0.49927498, 0.50003914,
108-
0.49980687, 0.50180495, 0.500905, 0.50035121], rtol=1e-2, atol=0)
109-
npt.assert_allclose(array_of_binomial, [0.7232, 0.131 , 0.9457, 0.8279, 0.2911, 0.8686, 0.57 , 0.9184,
110-
0.8177, 0.1625], rtol=1e-2, atol=0)
111-
assert isinstance(array_of_binomial, np.ndarray)
112-
assert isinstance(array_of_uniform, np.ndarray)
113-
114100

115101
class BaseTestCases(object):
116102
class BaseTestCase(SeededTest):
@@ -741,17 +727,17 @@ def test_lkj(self):
741727
for n in [2, 10, 50]:
742728
#pylint: disable=cell-var-from-loop
743729
shape = n*(n-1)//2
744-
730+
745731
def ref_rand(size, eta):
746732
beta = eta - 1 + n/2
747733
return (st.beta.rvs(size=(size, shape), a=beta, b=beta)-.5)*2
748734

749735
class TestedLKJCorr (pm.LKJCorr):
750-
736+
751737
def __init__(self, **kwargs):
752738
kwargs.pop('shape', None)
753739
super(TestedLKJCorr, self).__init__(
754-
n=n,
740+
n=n,
755741
**kwargs
756742
)
757743

@@ -767,12 +753,12 @@ def ref_rand(size, w, mu, sd):
767753

768754
pymc3_random(pm.NormalMixture, {'w': Simplex(2),
769755
'mu': Domain([[.05, 2.5], [-5., 1.]], edges=(None, None)),
770-
'sd': Domain([[1, 1], [1.5, 2.]], edges=(None, None))},
756+
'sd': Domain([[1, 1], [1.5, 2.]], edges=(None, None))},
771757
size=1000,
772758
ref_rand=ref_rand)
773759
pymc3_random(pm.NormalMixture, {'w': Simplex(3),
774760
'mu': Domain([[-5., 1., 2.5]], edges=(None, None)),
775-
'sd': Domain([[1.5, 2., 3.]], edges=(None, None))},
761+
'sd': Domain([[1.5, 2., 3.]], edges=(None, None))},
776762
size=1000,
777763
ref_rand=ref_rand)
778764

0 commit comments

Comments
 (0)