Skip to content

Commit ba639e5

Browse files
author
Junpeng Lao
authored
Merge branch 'master' into sample_SMC
2 parents 4d4abc0 + ea5cbf3 commit ba639e5

12 files changed

+338
-145
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
avoids pickleing issues on UNIX, and allows us to show a progress bar
1919
for all chains. If parallel sampling is interrupted, we now return
2020
partial results.
21+
- Add `sample_prior_predictive` which allows for efficient sampling from
22+
the unconditioned model.
2123

2224
### Fixes
2325

pymc3/distributions/bound.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def _random(self, lower, upper, point=None, size=None):
5454
samples = np.zeros(size, dtype=self.dtype).flatten()
5555
i, n = 0, len(samples)
5656
while i < len(samples):
57-
sample = self._wrapped.random(point=point, size=n)
57+
sample = np.atleast_1d(self._wrapped.random(point=point, size=n))
58+
5859
select = sample[np.logical_and(sample >= lower, sample <= upper)]
5960
samples[i:(i + len(select))] = select[:]
6061
i += len(select)

pymc3/distributions/discrete.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from functools import partial
21
import numpy as np
32
import theano
43
import theano.tensor as tt
@@ -7,7 +6,7 @@
76

87
from pymc3.util import get_variable_name
98
from .dist_math import bound, factln, binomln, betaln, logpow
10-
from .distribution import Discrete, draw_values, generate_samples, reshape_sampled
9+
from .distribution import Discrete, draw_values, generate_samples
1110
from pymc3.math import tround, sigmoid, logaddexp, logit, log1pexp
1211

1312

@@ -154,13 +153,20 @@ def __init__(self, alpha, beta, n, *args, **kwargs):
154153

155154
def _random(self, alpha, beta, n, size=None):
156155
size = size or 1
157-
p = np.atleast_1d(stats.beta.rvs(a=alpha, b=beta, size=np.prod(size)))
156+
p = stats.beta.rvs(a=alpha, b=beta, size=size).flatten()
158157
# Sometimes scipy.beta returns nan. Ugh.
159158
while np.any(np.isnan(p)):
160159
i = np.isnan(p)
161160
p[i] = stats.beta.rvs(a=alpha, b=beta, size=np.sum(i))
162161
# Sigh...
163-
_n, _p, _size = np.atleast_1d(n).flatten(), p.flatten(), np.prod(size)
162+
_n, _p, _size = np.atleast_1d(n).flatten(), p.flatten(), p.shape[0]
163+
164+
quotient, remainder = divmod(_p.shape[0], _n.shape[0])
165+
if remainder != 0:
166+
raise TypeError('n has a bad size! Was cast to {}, must evenly divide {}'.format(
167+
_n.shape[0], _p.shape[0]))
168+
if quotient != 1:
169+
_n = np.tile(_n, quotient)
164170
samples = np.reshape(stats.binom.rvs(n=_n, p=_p, size=_size), size)
165171
return samples
166172

@@ -186,7 +192,7 @@ def _repr_latex_(self, name=None, dist=None):
186192
alpha = dist.alpha
187193
beta = dist.beta
188194
name = r'\text{%s}' % name
189-
return r'${} \sim \text{{NegativeBinomial}}(\mathit{{alpha}}={},~\mathit{{beta}}={})$'.format(name,
195+
return r'${} \sim \text{{BetaBinomial}}(\mathit{{alpha}}={},~\mathit{{beta}}={})$'.format(name,
190196
get_variable_name(alpha),
191197
get_variable_name(beta))
192198

@@ -495,7 +501,7 @@ def random(self, point=None, size=None):
495501
dist_shape=self.shape,
496502
size=size)
497503
g[g == 0] = np.finfo(float).eps # Just in case
498-
return reshape_sampled(stats.poisson.rvs(g), size, self.shape)
504+
return np.asarray(stats.poisson.rvs(g)).reshape(g.shape)
499505

500506
def logp(self, value):
501507
mu = self.mu
@@ -700,22 +706,23 @@ def __init__(self, p, *args, **kwargs):
700706
self.k = tt.shape(p)[-1].tag.test_value
701707
except AttributeError:
702708
self.k = tt.shape(p)[-1]
703-
self.p = p = tt.as_tensor_variable(p)
709+
p = tt.as_tensor_variable(p)
704710
self.p = (p.T / tt.sum(p, -1)).T
705711
self.mode = tt.argmax(p)
706712

707-
def random(self, point=None, size=None):
708-
def random_choice(k, *args, **kwargs):
709-
if len(kwargs['p'].shape) > 1:
710-
return np.asarray(
711-
[np.random.choice(k, p=p)
712-
for p in kwargs['p']]
713-
)
714-
else:
715-
return np.random.choice(k, *args, **kwargs)
713+
def _random(self, k, p, size=None):
714+
if len(p.shape) > 1:
715+
return np.asarray(
716+
[np.random.choice(k, p=pp, size=size)
717+
for pp in p]
718+
)
719+
else:
720+
return np.asarray(np.random.choice(k, p=p, size=size))
716721

722+
def random(self, point=None, size=None):
717723
p, k = draw_values([self.p, self.k], point=point, size=size)
718-
return generate_samples(partial(random_choice, np.arange(k)),
724+
return generate_samples(self._random,
725+
k=k,
719726
p=p,
720727
broadcast_shape=p.shape[:-1] or (1,),
721728
dist_shape=self.shape,
@@ -849,8 +856,7 @@ def random(self, point=None, size=None):
849856
g = generate_samples(stats.poisson.rvs, theta,
850857
dist_shape=self.shape,
851858
size=size)
852-
sampled = g * (np.random.random(np.squeeze(g.shape)) < psi)
853-
return reshape_sampled(sampled, size, self.shape)
859+
return g * (np.random.random(np.squeeze(g.shape)) < psi)
854860

855861
def logp(self, value):
856862
psi = self.psi
@@ -942,8 +948,7 @@ def random(self, point=None, size=None):
942948
g = generate_samples(stats.binom.rvs, n, p,
943949
dist_shape=self.shape,
944950
size=size)
945-
sampled = g * (np.random.random(np.squeeze(g.shape)) < psi)
946-
return reshape_sampled(sampled, size, self.shape)
951+
return g * (np.random.random(np.squeeze(g.shape)) < psi)
947952

948953
def logp(self, value):
949954
psi = self.psi
@@ -1061,8 +1066,7 @@ def random(self, point=None, size=None):
10611066
dist_shape=self.shape,
10621067
size=size)
10631068
g[g == 0] = np.finfo(float).eps # Just in case
1064-
sampled = stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
1065-
return reshape_sampled(sampled, size, self.shape)
1069+
return stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
10661070

10671071
def logp(self, value):
10681072
alpha = self.alpha

pymc3/distributions/distribution.py

Lines changed: 96 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import collections
12
import numbers
3+
24
import numpy as np
35
import theano.tensor as tt
46
from theano import function
@@ -254,7 +256,7 @@ def draw_values(params, point=None, size=None):
254256

255257
# Init givens and the stack of nodes to try to `_draw_value` from
256258
givens = {}
257-
stored = set([]) # Some nodes
259+
stored = set() # Some nodes
258260
stack = list(leaf_nodes.values()) # A queue would be more appropriate
259261
while stack:
260262
next_ = stack.pop(0)
@@ -279,13 +281,14 @@ def draw_values(params, point=None, size=None):
279281
# The named node's children givens values must also be taken
280282
# into account.
281283
children = named_nodes_children[next_]
282-
temp_givens = [givens[k] for k in givens.keys() if k in children]
284+
temp_givens = [givens[k] for k in givens if k in children]
283285
try:
284286
# This may fail for autotransformed RVs, which don't
285287
# have the random method
286288
givens[next_.name] = (next_, _draw_value(next_,
287289
point=point,
288-
givens=temp_givens, size=size))
290+
givens=temp_givens,
291+
size=size))
289292
stored.add(next_.name)
290293
except theano.gof.fg.MissingInputError:
291294
# The node failed, so we must add the node's parents to
@@ -295,10 +298,31 @@ def draw_values(params, point=None, size=None):
295298
if node is not None and
296299
node.name not in stored and
297300
node not in params])
298-
values = []
299-
for param in params:
300-
values.append(_draw_value(param, point=point, givens=givens.values(), size=size))
301-
return values
301+
302+
# the below makes sure the graph is evaluated in order
303+
# test_distributions_random::TestDrawValues::test_draw_order fails without it
304+
params = dict(enumerate(params)) # some nodes are not hashable
305+
evaluated = {}
306+
to_eval = set()
307+
missing_inputs = set(params)
308+
while to_eval or missing_inputs:
309+
if to_eval == missing_inputs:
310+
raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval]))
311+
to_eval = set(missing_inputs)
312+
missing_inputs = set()
313+
for param_idx in to_eval:
314+
param = params[param_idx]
315+
if hasattr(param, 'name') and param.name in givens:
316+
evaluated[param_idx] = givens[param.name][1]
317+
else:
318+
try: # might evaluate in a bad order,
319+
evaluated[param_idx] = _draw_value(param, point=point, givens=givens.values(), size=size)
320+
if isinstance(param, collections.Hashable) and named_nodes_parents.get(param):
321+
givens[param.name] = (param, evaluated[param_idx])
322+
except theano.gof.fg.MissingInputError:
323+
missing_inputs.add(param_idx)
324+
325+
return [evaluated[j] for j in params] # set the order back
302326

303327

304328
@memoize
@@ -356,43 +380,26 @@ def _draw_value(param, point=None, givens=None, size=None):
356380
return point[param.name]
357381
elif hasattr(param, 'random') and param.random is not None:
358382
return param.random(point=point, size=size)
383+
elif (hasattr(param, 'distribution') and
384+
hasattr(param.distribution, 'random') and
385+
param.distribution.random is not None):
386+
return param.distribution.random(point=point, size=size)
359387
else:
360388
if givens:
361389
variables, values = list(zip(*givens))
362390
else:
363391
variables = values = []
364392
func = _compile_theano_function(param, variables)
365-
return func(*values)
393+
if size and values and not all(var.dshape == val.shape for var, val in zip(variables, values)):
394+
return np.array([func(*v) for v in zip(*values)])
395+
else:
396+
return func(*values)
366397
else:
367398
raise ValueError('Unexpected type in draw_value: %s' % type(param))
368399

369400

370-
def broadcast_shapes(*args):
371-
"""Return the shape resulting from broadcasting multiple shapes.
372-
Represents numpy's broadcasting rules.
373-
374-
Parameters
375-
----------
376-
*args : array-like of int
377-
Tuples or arrays or lists representing the shapes of arrays to be broadcast.
378-
379-
Returns
380-
-------
381-
Resulting shape or None if broadcasting is not possible.
382-
"""
383-
x = list(np.atleast_1d(args[0])) if args else ()
384-
for arg in args[1:]:
385-
y = list(np.atleast_1d(arg))
386-
if len(x) < len(y):
387-
x, y = y, x
388-
x[-len(y):] = [j if i == 1 else i if j == 1 else i if i == j else 0
389-
for i, j in zip(x[-len(y):], y)]
390-
if not all(x):
391-
return None
392-
return tuple(x)
393-
394-
395-
def infer_shape(shape):
401+
def to_tuple(shape):
402+
"""Convert ints, arrays, and Nones to tuples"""
396403
try:
397404
shape = tuple(shape or ())
398405
except TypeError: # If size is an int
@@ -401,27 +408,14 @@ def infer_shape(shape):
401408
shape = tuple(shape)
402409
return shape
403410

404-
405-
def reshape_sampled(sampled, size, dist_shape):
406-
dist_shape = infer_shape(dist_shape)
407-
repeat_shape = infer_shape(size)
408-
409-
if np.size(sampled) == 1 or repeat_shape or dist_shape:
410-
return np.reshape(sampled, repeat_shape + dist_shape)
411-
else:
412-
return sampled
413-
414-
415-
def replicate_samples(generator, size, repeats, *args, **kwargs):
416-
n = int(np.prod(repeats))
417-
if n == 1:
418-
samples = generator(size=size, *args, **kwargs)
419-
else:
420-
samples = np.array([generator(size=size, *args, **kwargs)
421-
for _ in range(n)])
422-
samples = np.reshape(samples, tuple(repeats) + tuple(size))
423-
return samples
424-
411+
def _is_one_d(dist_shape):
412+
if hasattr(dist_shape, 'dshape') and dist_shape.dshape in ((), (0,), (1,)):
413+
return True
414+
elif hasattr(dist_shape, 'shape') and dist_shape.shape in ((), (0,), (1,)):
415+
return True
416+
elif dist_shape == ():
417+
return True
418+
return False
425419

426420
def generate_samples(generator, *args, **kwargs):
427421
"""Generate samples from the distribution of a random variable.
@@ -453,42 +447,60 @@ def generate_samples(generator, *args, **kwargs):
453447
Any remaining *args and **kwargs are passed on to the generator function.
454448
"""
455449
dist_shape = kwargs.pop('dist_shape', ())
450+
one_d = _is_one_d(dist_shape)
456451
size = kwargs.pop('size', None)
457452
broadcast_shape = kwargs.pop('broadcast_shape', None)
458-
params = args + tuple(kwargs.values())
459-
460-
if broadcast_shape is None:
461-
broadcast_shape = broadcast_shapes(*[np.atleast_1d(p).shape for p in params
462-
if not isinstance(p, tuple)])
463-
if broadcast_shape == ():
464-
broadcast_shape = (1,)
453+
if size is None:
454+
size = 1
465455

466456
args = tuple(p[0] if isinstance(p, tuple) else p for p in args)
457+
467458
for key in kwargs:
468459
p = kwargs[key]
469460
kwargs[key] = p[0] if isinstance(p, tuple) else p
470461

471-
if np.all(dist_shape[-len(broadcast_shape):] == broadcast_shape):
472-
prefix_shape = tuple(dist_shape[:-len(broadcast_shape)])
473-
else:
474-
prefix_shape = tuple(dist_shape)
475-
476-
repeat_shape = infer_shape(size)
477-
478-
if broadcast_shape == (1,) and prefix_shape == ():
479-
if size is not None:
480-
samples = generator(size=size, *args, **kwargs)
462+
if broadcast_shape is None:
463+
inputs = args + tuple(kwargs.values())
464+
broadcast_shape = np.broadcast(*inputs).shape # size of generator(size=1)
465+
466+
dist_shape = to_tuple(dist_shape)
467+
broadcast_shape = to_tuple(broadcast_shape)
468+
size_tup = to_tuple(size)
469+
470+
# All inputs are scalars, end up size (size_tup, dist_shape)
471+
if broadcast_shape in {(), (0,), (1,)}:
472+
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
473+
# Inputs already have the right shape. Just get the right size.
474+
elif broadcast_shape[-len(dist_shape):] == dist_shape or len(dist_shape) == 0:
475+
if size == 1 or (broadcast_shape == size_tup + dist_shape):
476+
samples = generator(size=broadcast_shape, *args, **kwargs)
477+
elif dist_shape == broadcast_shape:
478+
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
481479
else:
482-
samples = generator(size=1, *args, **kwargs)
480+
samples = None
481+
# Args have been broadcast correctly, can just ask for the right shape out
482+
elif dist_shape[-len(broadcast_shape):] == broadcast_shape:
483+
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
484+
# Inputs have the right size, have to manually broadcast to the right dist_shape
485+
elif broadcast_shape[:len(size_tup)] == size_tup:
486+
suffix = broadcast_shape[len(size_tup):] + dist_shape
487+
samples = [generator(*args, **kwargs).reshape(size_tup + (1,)) for _ in range(np.prod(suffix, dtype=int))]
488+
samples = np.hstack(samples).reshape(size_tup + suffix)
483489
else:
484-
if size is not None:
485-
samples = replicate_samples(generator,
486-
broadcast_shape,
487-
repeat_shape + prefix_shape,
488-
*args, **kwargs)
489-
else:
490-
samples = replicate_samples(generator,
491-
broadcast_shape,
492-
prefix_shape,
493-
*args, **kwargs)
494-
return reshape_sampled(samples, size, dist_shape)
490+
samples = None
491+
492+
if samples is None:
493+
raise TypeError('''Attempted to generate values with incompatible shapes:
494+
size: {size}
495+
dist_shape: {dist_shape}
496+
broadcast_shape: {broadcast_shape}
497+
'''.format(size=size, dist_shape=dist_shape, broadcast_shape=broadcast_shape))
498+
499+
# reshape samples here
500+
if samples.shape[0] == 1 and size == 1:
501+
if len(samples.shape) > len(dist_shape) and samples.shape[-len(dist_shape):] == dist_shape:
502+
samples = samples.reshape(samples.shape[1:])
503+
504+
if one_d and samples.shape[-1] == 1:
505+
samples = samples.reshape(samples.shape[:-1])
506+
return np.asarray(samples)

0 commit comments

Comments
 (0)