Skip to content

Commit 889b50e

Browse files
ferrinetwiecki
authored andcommitted
tests passed
1 parent 07a248a commit 889b50e

File tree

3 files changed

+35
-20
lines changed

3 files changed

+35
-20
lines changed

pymc3/distributions/discrete.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from scipy import stats
77

88
from .dist_math import bound, factln, binomln, betaln, logpow
9-
from .distribution import Discrete, draw_values, generate_samples
9+
from .distribution import Discrete, draw_values, generate_samples, reshape_sampled
1010

1111
__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'Poisson',
1212
'NegativeBinomial', 'ConstantDist', 'Constant', 'ZeroInflatedPoisson',
@@ -250,7 +250,7 @@ def random(self, point=None, size=None, repeat=None):
250250
dist_shape=self.shape,
251251
size=size)
252252
g[g == 0] = np.finfo(float).eps # Just in case
253-
return stats.poisson.rvs(g)
253+
return reshape_sampled(stats.poisson.rvs(g), size, self.shape)
254254

255255
def logp(self, value):
256256
mu = self.mu
@@ -441,9 +441,11 @@ def logp(self, value):
441441
c = self.c
442442
return bound(0, tt.eq(value, c))
443443

444+
444445
def ConstantDist(*args, **kwargs):
446+
import warnings
445447
warnings.warn("ConstantDist has been deprecated. In future, use Constant instead.",
446-
DeprecationWarning)
448+
DeprecationWarning)
447449
return Constant(*args, **kwargs)
448450

449451

@@ -489,7 +491,8 @@ def random(self, point=None, size=None, repeat=None):
489491
g = generate_samples(stats.poisson.rvs, theta,
490492
dist_shape=self.shape,
491493
size=size)
492-
return g * (np.random.random(np.squeeze(g.shape)) < psi)
494+
sampled = g * (np.random.random(np.squeeze(g.shape)) < psi)
495+
return reshape_sampled(sampled, size, self.shape)
493496

494497
def logp(self, value):
495498
return tt.switch(value > 0,
@@ -543,7 +546,8 @@ def random(self, point=None, size=None, repeat=None):
543546
dist_shape=self.shape,
544547
size=size)
545548
g[g == 0] = np.finfo(float).eps # Just in case
546-
return stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
549+
sampled = stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
550+
return reshape_sampled(sampled, size, self.shape)
547551

548552
def logp(self, value):
549553
return tt.switch(value > 0,

pymc3/distributions/distribution.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
__all__ = ['DensityDist', 'Distribution', 'Continuous',
1111
'Discrete', 'NoDistribution', 'TensorType', 'draw_values']
1212

13+
1314
class _Unpickling(object):
1415
pass
1516

17+
1618
class Distribution(object):
1719
"""Statistical distribution"""
1820
def __new__(cls, name, *args, **kwargs):
@@ -129,12 +131,10 @@ def __init__(self, logp, shape=(), dtype='float64', testval=0, *args, **kwargs):
129131

130132

131133
class MultivariateContinuous(Continuous):
132-
133134
pass
134135

135136

136137
class MultivariateDiscrete(Discrete):
137-
138138
pass
139139

140140

@@ -265,6 +265,22 @@ def broadcast_shapes(*args):
265265
return tuple(x)
266266

267267

268+
def infer_shape(shape):
269+
try:
270+
shape = tuple(shape or ())
271+
except TypeError: # If size is an int
272+
shape = tuple((shape,))
273+
except ValueError: # If size is np.array
274+
shape = tuple(shape)
275+
return shape
276+
277+
278+
def reshape_sampled(sampled, size, dist_shape):
279+
dist_shape = infer_shape(dist_shape)
280+
repeat_shape = infer_shape(size)
281+
return np.reshape(sampled, repeat_shape + dist_shape)
282+
283+
268284
def replicate_samples(generator, size, repeats, *args, **kwargs):
269285
n = int(np.prod(repeats))
270286
if n == 1:
@@ -326,10 +342,7 @@ def generate_samples(generator, *args, **kwargs):
326342
else:
327343
prefix_shape = tuple(dist_shape)
328344

329-
try:
330-
repeat_shape = tuple(size or ())
331-
except TypeError: # If size is an int
332-
repeat_shape = tuple((size,))
345+
repeat_shape = infer_shape(size)
333346

334347
if broadcast_shape == (1,) and prefix_shape == ():
335348
if size is not None:
@@ -342,13 +355,9 @@ def generate_samples(generator, *args, **kwargs):
342355
broadcast_shape,
343356
repeat_shape + prefix_shape,
344357
*args, **kwargs)
345-
if broadcast_shape == (1,) and not prefix_shape == ():
346-
samples = np.reshape(samples, repeat_shape + prefix_shape)
347358
else:
348359
samples = replicate_samples(generator,
349360
broadcast_shape,
350361
prefix_shape,
351362
*args, **kwargs)
352-
if broadcast_shape == (1,):
353-
samples = np.reshape(samples, prefix_shape)
354-
return samples
363+
return reshape_sampled(samples, size, dist_shape)

pymc3/tests/test_distributions_random.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,10 @@ def test_different_shapes_and_sample_sizes(self):
178178
except TypeError:
179179
s = [size]
180180
s.extend(shape)
181-
expected.append(tuple(s))
182-
actual.append(self.sample_random_variable(rv, size).shape)
181+
e = tuple(s)
182+
a = self.sample_random_variable(rv, size).shape
183+
expected.append(e)
184+
actual.append(a)
183185
self.assertSequenceEqual(expected, actual)
184186

185187

@@ -332,8 +334,8 @@ class TestCategorical(BaseTestCases.BaseTestCase):
332334
distribution = pm.Categorical
333335
params = {'p': np.ones(BaseTestCases.BaseTestCase.shape)}
334336

335-
def get_random_variable(self, shape, with_vector_params=False): # don't transform categories
336-
return super(TestCategorical, self).get_random_variable(shape, with_vector_params=False)
337+
def get_random_variable(self, shape, with_vector_params=False, **kwargs): # don't transform categories
338+
return super(TestCategorical, self).get_random_variable(shape, with_vector_params=False, **kwargs)
337339

338340

339341
@attr('scalar_parameter_samples')

0 commit comments

Comments
 (0)