diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 15e024e23d..83b7d14f09 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -383,7 +383,17 @@ def _draw_value(param, point=None, givens=None, size=None): elif (hasattr(param, 'distribution') and hasattr(param.distribution, 'random') and param.distribution.random is not None): - return param.distribution.random(point=point, size=size) + # reset the dist shape for ObservedRV + if hasattr(param, 'observations'): + dist_tmp = param.distribution + try: + distshape = param.observations.shape.eval() + except AttributeError: + distshape = param.observations.shape + dist_tmp.shape = distshape + return dist_tmp.random(point=point, size=size) + else: + return param.distribution.random(point=point, size=size) else: if givens: variables, values = list(zip(*givens)) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 0ad6662102..c85cb52014 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -353,6 +353,7 @@ def test_ignores_observed(self): assert (prior['mu'] < 90).all() assert (prior['positive_mu'] > 90).all() assert (prior['x_obs'] < 90).all() + assert prior['x_obs'].shape == (500, 200) npt.assert_array_almost_equal(prior['positive_mu'], np.abs(prior['mu']), decimal=4) def test_respects_shape(self): @@ -395,9 +396,28 @@ def test_transformed(self): thetas = pm.Beta('thetas', alpha=phi*kappa, beta=(1.0-phi)*kappa, shape=n) - y = pm.Binomial('y', n=at_bats, p=thetas, shape=n, observed=hits) + y = pm.Binomial('y', n=at_bats, p=thetas, observed=hits) gen = pm.sample_prior_predictive(draws) assert gen['phi'].shape == (draws,) assert gen['y'].shape == (draws, n) - assert 'thetas_logodds__' in gen \ No newline at end of file + assert 'thetas_logodds__' in gen + + def test_shared(self): + n1 = 10 + obs = shared(np.random.rand(n1) < .5) + draws = 50 + + with pm.Model() as m: + p = pm.Beta('p', 1., 1.) + y = pm.Bernoulli('y', p, observed=obs) + gen1 = pm.sample_prior_predictive(draws) + + assert gen1['y'].shape == (draws, n1) + + n2 = 20 + obs.set_value(np.random.rand(n2) < .5) + with m: + gen2 = pm.sample_prior_predictive(draws) + + assert gen2['y'].shape == (draws, n2)