Skip to content

Commit 9da9e27

Browse files
ColCarrollJunpeng Lao
authored and
Junpeng Lao
committed
Update from comments
1 parent 7185791 commit 9da9e27

File tree

5 files changed

+38
-9
lines changed

5 files changed

+38
-9
lines changed

pymc3/distributions/distribution.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -476,10 +476,8 @@ def generate_samples(generator, *args, **kwargs):
476476
samples = generator(size=broadcast_shape, *args, **kwargs)
477477
elif dist_shape == broadcast_shape:
478478
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
479-
elif size_tup[-len(broadcast_shape):] != broadcast_shape:
480-
samples = generator(size=size_tup + broadcast_shape, *args, **kwargs)
481479
else:
482-
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
480+
samples = None
483481
# Args have been broadcast correctly, can just ask for the right shape out
484482
elif dist_shape[-len(broadcast_shape):] == broadcast_shape:
485483
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
@@ -489,6 +487,9 @@ def generate_samples(generator, *args, **kwargs):
489487
samples = [generator(*args, **kwargs).reshape(size_tup + (1,)) for _ in range(np.prod(suffix, dtype=int))]
490488
samples = np.hstack(samples).reshape(size_tup + suffix)
491489
else:
490+
samples = None
491+
492+
if samples is None:
492493
raise TypeError('''Attempted to generate values with incompatible shapes:
493494
size: {size}
494495
dist_shape: {dist_shape}

pymc3/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,6 +1457,8 @@ def __init__(self, type=None, owner=None, index=None, name=None,
14571457
if distribution is not None:
14581458
self.model = model
14591459
self.distribution = distribution
1460+
self.dshape = tuple(distribution.shape)
1461+
self.dsize = int(np.prod(distribution.shape))
14601462

14611463
transformed_name = get_transformed_name(name, transform)
14621464

pymc3/sampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,13 +1277,13 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None,
12771277
return {k: np.asarray(v) for k, v in ppc.items()}
12781278

12791279

1280-
def sample_generative(samples=500, model=None, vars=None, random_seed=None):
1281-
"""Generate samples from the generative model.
1280+
def sample_prior_predictive(samples=500, model=None, vars=None, random_seed=None):
1281+
"""Generate samples from the prior_predictive distribution.
12821282
12831283
Parameters
12841284
----------
12851285
samples : int
1286-
Number of samples from the prior to generate. Defaults to 500.
1286+
Number of samples from the prior predictive to generate. Defaults to 500.
12871287
model : Model (optional if in `with` context)
12881288
vars : iterable
12891289
Variables for which to compute the posterior predictive samples.

pymc3/tests/test_distributions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,10 +909,15 @@ def test_multinomial_mode(self, p, n):
909909
[[.25, .25, .25, .25], (1, 4), 3],
910910
# 3: expect to fail
911911
# [[.25, .25, .25, .25], (10, 4)],
912+
[[.25, .25, .25, .25], (10, 1, 4), 5],
912913
# 5: expect to fail
913914
# [[[.25, .25, .25, .25]], (2, 4), [7, 11]],
914915
[[[.25, .25, .25, .25],
915916
[.25, .25, .25, .25]], (2, 4), 13],
917+
[[[.25, .25, .25, .25],
918+
[.25, .25, .25, .25]], (1, 2, 4), [23, 29]],
919+
[[[.25, .25, .25, .25],
920+
[.25, .25, .25, .25]], (10, 2, 4), [31, 37]],
916921
[[[.25, .25, .25, .25],
917922
[.25, .25, .25, .25]], (2, 4), [17, 19]],
918923
])

pymc3/tests/test_sampling.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def test_ignores_observed(self):
348348
positive_mu = pm.Deterministic('positive_mu', np.abs(mu))
349349
z = -1 - positive_mu
350350
pm.Normal('x_obs', mu=z, sd=1, observed=observed)
351-
prior = pm.sample_generative()
351+
prior = pm.sample_prior_predictive()
352352

353353
assert (prior['mu'] < 90).all()
354354
assert (prior['positive_mu'] > 90).all()
@@ -360,15 +360,15 @@ def test_respects_shape(self):
360360
with pm.Model():
361361
mu = pm.Gamma('mu', 3, 1, shape=1)
362362
goals = pm.Poisson('goals', mu, shape=shape)
363-
trace = pm.sample_generative(10)
363+
trace = pm.sample_prior_predictive(10)
364364
if shape == 2: # want to test shape as an int
365365
shape = (2,)
366366
assert trace['goals'].shape == (10,) + shape
367367

368368
def test_multivariate(self):
369369
with pm.Model():
370370
m = pm.Multinomial('m', n=5, p=np.array([0.25, 0.25, 0.25, 0.25]), shape=4)
371-
trace = pm.sample_generative(10)
371+
trace = pm.sample_prior_predictive(10)
372372

373373
assert m.random(size=10).shape == (10, 4)
374374
assert trace['m'].shape == (10, 4)
@@ -380,3 +380,24 @@ def test_layers(self):
380380

381381
avg = b.random(size=10000).mean(axis=0)
382382
npt.assert_array_almost_equal(avg, 0.5 * np.ones_like(b), decimal=2)
383+
384+
def test_transformed(self):
385+
n = 18
386+
at_bats = 45 * np.ones(n, dtype=int)
387+
hits = np.random.randint(1, 40, size=n, dtype=int)
388+
draws = 50
389+
390+
with pm.Model() as model:
391+
phi = pm.Beta('phi', alpha=1., beta=1.)
392+
393+
kappa_log = pm.Exponential('logkappa', lam=5.)
394+
kappa = pm.Deterministic('kappa', tt.exp(kappa_log))
395+
396+
thetas = pm.Beta('thetas', alpha=phi*kappa, beta=(1.0-phi)*kappa, shape=n)
397+
398+
y = pm.Binomial('y', n=at_bats, p=thetas, shape=n, observed=hits)
399+
gen = pm.sample_prior_predictive(draws)
400+
401+
assert gen['phi'].shape == (draws,)
402+
assert gen['y'].shape == (draws, n)
403+
assert 'thetas_logodds__' in gen

0 commit comments

Comments
 (0)