diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index bf814c9ff5..aced473642 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -18,6 +18,8 @@
   avoids pickling issues on UNIX, and allows us to show a progress bar for
   all chains. If parallel sampling is interrupted, we now return partial
   results.
+- Add `sample_prior_predictive` which allows for efficient sampling from
+  the unconditioned model.
 
 ### Fixes
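A minimal usage sketch of the new function (the model here is hypothetical and the names `mu` and `obs` are illustrative; the shapes follow the tests added in pymc3/tests/test_sampling.py below):

import numpy as np
import pymc3 as pm

# Observed nodes are sampled rather than conditioned on, so the data is ignored.
with pm.Model():
    mu = pm.Normal('mu', mu=0., sd=1.)
    pm.Normal('obs', mu=mu, sd=1., observed=np.zeros(10))
    prior = pm.sample_prior_predictive(samples=1000)

# The result maps variable names to arrays of draws from the unconditioned model.
assert prior['mu'].shape == (1000,)
assert prior['obs'].shape == (1000, 10)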
diff --git a/pymc3/distributions/bound.py b/pymc3/distributions/bound.py
index c44684b061..35d489ed2c 100644
--- a/pymc3/distributions/bound.py
+++ b/pymc3/distributions/bound.py
@@ -54,7 +54,8 @@ def _random(self, lower, upper, point=None, size=None):
         samples = np.zeros(size, dtype=self.dtype).flatten()
         i, n = 0, len(samples)
         while i < len(samples):
-            sample = self._wrapped.random(point=point, size=n)
+            sample = np.atleast_1d(self._wrapped.random(point=point, size=n))
+
             select = sample[np.logical_and(sample >= lower, sample <= upper)]
             samples[i:(i + len(select))] = select[:]
             i += len(select)
diff --git a/pymc3/distributions/discrete.py b/pymc3/distributions/discrete.py
index 58ae48df2b..801386a097 100644
--- a/pymc3/distributions/discrete.py
+++ b/pymc3/distributions/discrete.py
@@ -1,4 +1,3 @@
-from functools import partial
 import numpy as np
 import theano
 import theano.tensor as tt
@@ -7,7 +6,7 @@
 from pymc3.util import get_variable_name
 from .dist_math import bound, factln, binomln, betaln, logpow
-from .distribution import Discrete, draw_values, generate_samples, reshape_sampled
+from .distribution import Discrete, draw_values, generate_samples
 from pymc3.math import tround, sigmoid, logaddexp, logit, log1pexp
@@ -154,13 +153,20 @@ def __init__(self, alpha, beta, n, *args, **kwargs):
 
     def _random(self, alpha, beta, n, size=None):
         size = size or 1
-        p = np.atleast_1d(stats.beta.rvs(a=alpha, b=beta, size=np.prod(size)))
+        p = stats.beta.rvs(a=alpha, b=beta, size=size).flatten()
         # Sometimes scipy.beta returns nan. Ugh.
         while np.any(np.isnan(p)):
             i = np.isnan(p)
             p[i] = stats.beta.rvs(a=alpha, b=beta, size=np.sum(i))
         # Sigh...
-        _n, _p, _size = np.atleast_1d(n).flatten(), p.flatten(), np.prod(size)
+        _n, _p, _size = np.atleast_1d(n).flatten(), p.flatten(), p.shape[0]
+
+        quotient, remainder = divmod(_p.shape[0], _n.shape[0])
+        if remainder != 0:
+            raise TypeError('n has a bad size! Was cast to {}, must evenly divide {}'.format(
+                _n.shape[0], _p.shape[0]))
+        if quotient != 1:
+            _n = np.tile(_n, quotient)
         samples = np.reshape(stats.binom.rvs(n=_n, p=_p, size=_size), size)
         return samples
@@ -186,7 +192,7 @@ def _repr_latex_(self, name=None, dist=None):
         alpha = dist.alpha
         beta = dist.beta
         name = r'\text{%s}' % name
-        return r'${} \sim \text{{NegativeBinomial}}(\mathit{{alpha}}={},~\mathit{{beta}}={})$'.format(name,
+        return r'${} \sim \text{{BetaBinomial}}(\mathit{{alpha}}={},~\mathit{{beta}}={})$'.format(name,
                                                 get_variable_name(alpha),
                                                 get_variable_name(beta))
@@ -495,7 +501,7 @@ def random(self, point=None, size=None):
                              dist_shape=self.shape,
                              size=size)
         g[g == 0] = np.finfo(float).eps  # Just in case
-        return reshape_sampled(stats.poisson.rvs(g), size, self.shape)
+        return np.asarray(stats.poisson.rvs(g)).reshape(g.shape)
 
     def logp(self, value):
         mu = self.mu
@@ -700,22 +706,23 @@ def __init__(self, p, *args, **kwargs):
             self.k = tt.shape(p)[-1].tag.test_value
         except AttributeError:
             self.k = tt.shape(p)[-1]
-        self.p = p = tt.as_tensor_variable(p)
+        p = tt.as_tensor_variable(p)
         self.p = (p.T / tt.sum(p, -1)).T
         self.mode = tt.argmax(p)
 
-    def random(self, point=None, size=None):
-        def random_choice(k, *args, **kwargs):
-            if len(kwargs['p'].shape) > 1:
-                return np.asarray(
-                    [np.random.choice(k, p=p)
-                     for p in kwargs['p']]
-                )
-            else:
-                return np.random.choice(k, *args, **kwargs)
+    def _random(self, k, p, size=None):
+        if len(p.shape) > 1:
+            return np.asarray(
+                [np.random.choice(k, p=pp, size=size)
+                 for pp in p]
+            )
+        else:
+            return np.asarray(np.random.choice(k, p=p, size=size))
 
+    def random(self, point=None, size=None):
         p, k = draw_values([self.p, self.k], point=point, size=size)
-        return generate_samples(partial(random_choice, np.arange(k)),
+        return generate_samples(self._random,
+                                k=k,
                                 p=p,
                                 broadcast_shape=p.shape[:-1] or (1,),
                                 dist_shape=self.shape,
@@ -849,8 +856,7 @@ def random(self, point=None, size=None):
         g = generate_samples(stats.poisson.rvs, theta,
                              dist_shape=self.shape,
                              size=size)
-        sampled = g * (np.random.random(np.squeeze(g.shape)) < psi)
-        return reshape_sampled(sampled, size, self.shape)
+        return g * (np.random.random(np.squeeze(g.shape)) < psi)
 
     def logp(self, value):
         psi = self.psi
@@ -942,8 +948,7 @@ def random(self, point=None, size=None):
         g = generate_samples(stats.binom.rvs, n, p,
                              dist_shape=self.shape,
                              size=size)
-        sampled = g * (np.random.random(np.squeeze(g.shape)) < psi)
-        return reshape_sampled(sampled, size, self.shape)
+        return g * (np.random.random(np.squeeze(g.shape)) < psi)
 
     def logp(self, value):
         psi = self.psi
@@ -1061,8 +1066,7 @@ def random(self, point=None, size=None):
                              dist_shape=self.shape,
                              size=size)
         g[g == 0] = np.finfo(float).eps  # Just in case
-        sampled = stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
-        return reshape_sampled(sampled, size, self.shape)
+        return stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
 
     def logp(self, value):
         alpha = self.alpha
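A worked example of the new divmod/tile guard in BetaBinomial._random, in standalone numpy with illustrative values: when more beta draws than trial counts are requested, n must divide the number of draws evenly and is tiled to match.

import numpy as np

_n = np.array([10, 20])                    # per-component trial counts
_p = np.random.beta(2., 2., size=6)        # three draws per component
quotient, remainder = divmod(_p.shape[0], _n.shape[0])
assert remainder == 0                      # otherwise the new code raises TypeError
_n = np.tile(_n, quotient)                 # [10, 20, 10, 20, 10, 20]
samples = np.random.binomial(n=_n, p=_p)   # one draw per (n, p) pair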
diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py
index bdb7e266bd..15e024e23d 100644
--- a/pymc3/distributions/distribution.py
+++ b/pymc3/distributions/distribution.py
@@ -1,4 +1,6 @@
+import collections
 import numbers
+
 import numpy as np
 import theano.tensor as tt
 from theano import function
@@ -254,7 +256,7 @@ def draw_values(params, point=None, size=None):
 
     # Init givens and the stack of nodes to try to `_draw_value` from
     givens = {}
-    stored = set([])  # Some nodes
+    stored = set()  # Some nodes
     stack = list(leaf_nodes.values())  # A queue would be more appropriate
     while stack:
         next_ = stack.pop(0)
@@ -279,13 +281,14 @@
             # The named node's children givens values must also be taken
             # into account.
             children = named_nodes_children[next_]
-            temp_givens = [givens[k] for k in givens.keys() if k in children]
+            temp_givens = [givens[k] for k in givens if k in children]
             try:
                 # This may fail for autotransformed RVs, which don't
                 # have the random method
                 givens[next_.name] = (next_, _draw_value(next_,
                                                          point=point,
-                                                         givens=temp_givens, size=size))
+                                                         givens=temp_givens,
+                                                         size=size))
                 stored.add(next_.name)
             except theano.gof.fg.MissingInputError:
                 # The node failed, so we must add the node's parents to
@@ -295,10 +298,31 @@
                               if node is not None and
                               node.name not in stored and
                               node not in params])
-    values = []
-    for param in params:
-        values.append(_draw_value(param, point=point, givens=givens.values(), size=size))
-    return values
+
+    # the below makes sure the graph is evaluated in order
+    # test_distributions_random::TestDrawValues::test_draw_order fails without it
+    params = dict(enumerate(params))  # some nodes are not hashable
+    evaluated = {}
+    to_eval = set()
+    missing_inputs = set(params)
+    while to_eval or missing_inputs:
+        if to_eval == missing_inputs:
+            raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval]))
+        to_eval = set(missing_inputs)
+        missing_inputs = set()
+        for param_idx in to_eval:
+            param = params[param_idx]
+            if hasattr(param, 'name') and param.name in givens:
+                evaluated[param_idx] = givens[param.name][1]
+            else:
+                try:  # might evaluate in a bad order,
+                    evaluated[param_idx] = _draw_value(param, point=point, givens=givens.values(), size=size)
+                    if isinstance(param, collections.Hashable) and named_nodes_parents.get(param):
+                        givens[param.name] = (param, evaluated[param_idx])
+                except theano.gof.fg.MissingInputError:
+                    missing_inputs.add(param_idx)
+
+    return [evaluated[j] for j in params]  # set the order back
 
 
 @memoize
@@ -356,43 +380,26 @@ def _draw_value(param, point=None, givens=None, size=None):
             return point[param.name]
         elif hasattr(param, 'random') and param.random is not None:
             return param.random(point=point, size=size)
+        elif (hasattr(param, 'distribution') and
+                hasattr(param.distribution, 'random') and
+                param.distribution.random is not None):
+            return param.distribution.random(point=point, size=size)
        else:
             if givens:
                 variables, values = list(zip(*givens))
             else:
                 variables = values = []
             func = _compile_theano_function(param, variables)
-            return func(*values)
+            if size and values and not all(var.dshape == val.shape for var, val in zip(variables, values)):
+                return np.array([func(*v) for v in zip(*values)])
+            else:
+                return func(*values)
     else:
         raise ValueError('Unexpected type in draw_value: %s' % type(param))
 
 
-def broadcast_shapes(*args):
-    """Return the shape resulting from broadcasting multiple shapes.
-    Represents numpy's broadcasting rules.
-
-    Parameters
-    ----------
-    *args : array-like of int
-        Tuples or arrays or lists representing the shapes of arrays to be broadcast.
-
-    Returns
-    -------
-    Resulting shape or None if broadcasting is not possible.
-    """
-    x = list(np.atleast_1d(args[0])) if args else ()
-    for arg in args[1:]:
-        y = list(np.atleast_1d(arg))
-        if len(x) < len(y):
-            x, y = y, x
-        x[-len(y):] = [j if i == 1 else i if j == 1 else i if i == j else 0
-                       for i, j in zip(x[-len(y):], y)]
-    if not all(x):
-        return None
-    return tuple(x)
-
-
-def infer_shape(shape):
+def to_tuple(shape):
+    """Convert ints, arrays, and Nones to tuples"""
     try:
         shape = tuple(shape or ())
     except TypeError:  # If size is an int
@@ -401,27 +408,14 @@
         shape = tuple(shape)
     return shape
 
-
-def reshape_sampled(sampled, size, dist_shape):
-    dist_shape = infer_shape(dist_shape)
-    repeat_shape = infer_shape(size)
-
-    if np.size(sampled) == 1 or repeat_shape or dist_shape:
-        return np.reshape(sampled, repeat_shape + dist_shape)
-    else:
-        return sampled
-
-
-def replicate_samples(generator, size, repeats, *args, **kwargs):
-    n = int(np.prod(repeats))
-    if n == 1:
-        samples = generator(size=size, *args, **kwargs)
-    else:
-        samples = np.array([generator(size=size, *args, **kwargs)
-                            for _ in range(n)])
-        samples = np.reshape(samples, tuple(repeats) + tuple(size))
-    return samples
-
+def _is_one_d(dist_shape):
+    if hasattr(dist_shape, 'dshape') and dist_shape.dshape in ((), (0,), (1,)):
+        return True
+    elif hasattr(dist_shape, 'shape') and dist_shape.shape in ((), (0,), (1,)):
+        return True
+    elif dist_shape == ():
+        return True
+    return False
 
 def generate_samples(generator, *args, **kwargs):
     """Generate samples from the distribution of a random variable.
- """ - x = list(np.atleast_1d(args[0])) if args else () - for arg in args[1:]: - y = list(np.atleast_1d(arg)) - if len(x) < len(y): - x, y = y, x - x[-len(y):] = [j if i == 1 else i if j == 1 else i if i == j else 0 - for i, j in zip(x[-len(y):], y)] - if not all(x): - return None - return tuple(x) - - -def infer_shape(shape): +def to_tuple(shape): + """Convert ints, arrays, and Nones to tuples""" try: shape = tuple(shape or ()) except TypeError: # If size is an int @@ -401,27 +408,14 @@ def infer_shape(shape): shape = tuple(shape) return shape - -def reshape_sampled(sampled, size, dist_shape): - dist_shape = infer_shape(dist_shape) - repeat_shape = infer_shape(size) - - if np.size(sampled) == 1 or repeat_shape or dist_shape: - return np.reshape(sampled, repeat_shape + dist_shape) - else: - return sampled - - -def replicate_samples(generator, size, repeats, *args, **kwargs): - n = int(np.prod(repeats)) - if n == 1: - samples = generator(size=size, *args, **kwargs) - else: - samples = np.array([generator(size=size, *args, **kwargs) - for _ in range(n)]) - samples = np.reshape(samples, tuple(repeats) + tuple(size)) - return samples - +def _is_one_d(dist_shape): + if hasattr(dist_shape, 'dshape') and dist_shape.dshape in ((), (0,), (1,)): + return True + elif hasattr(dist_shape, 'shape') and dist_shape.shape in ((), (0,), (1,)): + return True + elif dist_shape == (): + return True + return False def generate_samples(generator, *args, **kwargs): """Generate samples from the distribution of a random variable. @@ -453,42 +447,60 @@ def generate_samples(generator, *args, **kwargs): Any remaining *args and **kwargs are passed on to the generator function. """ dist_shape = kwargs.pop('dist_shape', ()) + one_d = _is_one_d(dist_shape) size = kwargs.pop('size', None) broadcast_shape = kwargs.pop('broadcast_shape', None) - params = args + tuple(kwargs.values()) - - if broadcast_shape is None: - broadcast_shape = broadcast_shapes(*[np.atleast_1d(p).shape for p in params - if not isinstance(p, tuple)]) - if broadcast_shape == (): - broadcast_shape = (1,) + if size is None: + size = 1 args = tuple(p[0] if isinstance(p, tuple) else p for p in args) + for key in kwargs: p = kwargs[key] kwargs[key] = p[0] if isinstance(p, tuple) else p - if np.all(dist_shape[-len(broadcast_shape):] == broadcast_shape): - prefix_shape = tuple(dist_shape[:-len(broadcast_shape)]) - else: - prefix_shape = tuple(dist_shape) - - repeat_shape = infer_shape(size) - - if broadcast_shape == (1,) and prefix_shape == (): - if size is not None: - samples = generator(size=size, *args, **kwargs) + if broadcast_shape is None: + inputs = args + tuple(kwargs.values()) + broadcast_shape = np.broadcast(*inputs).shape # size of generator(size=1) + + dist_shape = to_tuple(dist_shape) + broadcast_shape = to_tuple(broadcast_shape) + size_tup = to_tuple(size) + + # All inputs are scalars, end up size (size_tup, dist_shape) + if broadcast_shape in {(), (0,), (1,)}: + samples = generator(size=size_tup + dist_shape, *args, **kwargs) + # Inputs already have the right shape. Just get the right size. 
diff --git a/pymc3/distributions/mixture.py b/pymc3/distributions/mixture.py
index 7d9885bf56..86e71cfa3a 100644
--- a/pymc3/distributions/mixture.py
+++ b/pymc3/distributions/mixture.py
@@ -49,7 +49,7 @@ class Mixture(Distribution):
             lam = pm.Exponential('lam', lam=1, shape=(2,))  # `shape=(2,)` indicates two mixtures.
 
             # As we just need the logp, rather than add a RV to the model, we need to call .dist()
-            components = pm.Poisson.dist(mu=lam, shape=(2,)) 
+            components = pm.Poisson.dist(mu=lam, shape=(2,))
 
             w = pm.Dirichlet('w', a=np.array([1, 1]))  # two mixture component weights.
@@ -175,6 +175,8 @@ def random_choice(*args, **kwargs):
             else:
                 samples = np.squeeze(comp_samples[w_samples])
         else:
+            if w_samples.ndim == 1:
+                w_samples = np.reshape(np.tile(w_samples, size), (size,) + w_samples.shape)
             samples = np.zeros((size,)+tuple(distshape))
             for i in range(size):
                 w_tmp = w_samples[i, :]
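The new mixture branch above tiles a single weight vector so that each of the `size` draws gets its own row; the same operation in isolation:

import numpy as np

w_samples = np.array([0.3, 0.7])   # one weight vector, ndim == 1
size = 4
w_samples = np.reshape(np.tile(w_samples, size), (size,) + w_samples.shape)
assert w_samples.shape == (4, 2)   # row i is used for draw i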
diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py
index d1786009a5..4717f90039 100755
--- a/pymc3/distributions/multivariate.py
+++ b/pymc3/distributions/multivariate.py
@@ -409,10 +409,12 @@ class Dirichlet(Continuous):
 
     def __init__(self, a, transform=transforms.stick_breaking,
                  *args, **kwargs):
-        shape = a.shape[-1]
+        shape = np.atleast_1d(a.shape)[-1]
+
+        kwargs.setdefault("shape", shape)
         super(Dirichlet, self).__init__(transform=transform, *args, **kwargs)
 
+        self.size_prefix = tuple(self.shape[:-1])
         self.k = tt.as_tensor_variable(shape)
         self.a = a = tt.as_tensor_variable(a)
         self.mean = a / tt.sum(a)
@@ -421,13 +423,31 @@ def __init__(self, a, transform=transforms.stick_breaking,
                               (a - 1) / tt.sum(a - 1),
                               np.nan)
 
-    def random(self, point=None, size=None):
-        a = draw_values([self.a], point=point, size=size)[0]
+    def _random(self, a, size=None):
+        gen = stats.dirichlet.rvs
+        shape = tuple(np.atleast_1d(self.shape))
+        if size[-len(shape):] == shape:
+            real_size = size[:-len(shape)]
+        else:
+            real_size = size
+        if self.size_prefix:
+            if real_size and real_size[0] == 1:
+                real_size = real_size[1:] + self.size_prefix
+            else:
+                real_size = real_size + self.size_prefix
 
-        def _random(a, size=None):
-            return stats.dirichlet.rvs(a, None if size == a.shape else size)
+        if a.ndim == 1:
+            samples = gen(alpha=a, size=real_size)
+        else:
+            unrolled = a.reshape((np.prod(a.shape[:-1]), a.shape[-1]))
+            samples = np.array([gen(alpha=aa, size=1) for aa in unrolled])
+            samples = samples.reshape(a.shape)
+        return samples
 
-        samples = generate_samples(_random, a,
+    def random(self, point=None, size=None):
+        a = draw_values([self.a], point=point, size=size)[0]
+        samples = generate_samples(self._random,
+                                   a=a,
                                    dist_shape=self.shape,
                                    size=size)
         return samples
@@ -492,10 +512,10 @@ def __init__(self, n, p, *args, **kwargs):
 
         if len(self.shape) > 1:
             m = self.shape[-2]
-            try:
-                assert n.shape == (m,)
-            except (AttributeError, AssertionError):
-                n = n * tt.ones(m)
+            # try:
+            #     assert n.shape == (m,)
+            # except (AttributeError, AssertionError):
+            #     n = n * tt.ones(m)
             self.n = tt.shape_padright(n)
             self.p = p if p.ndim > 1 else tt.shape_padleft(p)
         elif n.ndim == 1:
@@ -521,27 +541,35 @@ def _random(self, n, p, size=None):
         # Now, re-normalize all of the values in float64 precision. This is done inside the conditionals
         if size == p.shape:
             size = None
-        if (n.ndim == 0) and (p.ndim == 1):
+        elif size[-len(p.shape):] == p.shape:
+            size = size[:len(size) - len(p.shape)]
+
+        n_dim = n.squeeze().ndim
+
+        if (n_dim == 0) and (p.ndim == 1):
             p = p / p.sum()
             randnum = np.random.multinomial(n, p.squeeze(), size=size)
-        elif (n.ndim == 0) and (p.ndim > 1):
+        elif (n_dim == 0) and (p.ndim > 1):
             p = p / p.sum(axis=1, keepdims=True)
             randnum = np.asarray([
                 np.random.multinomial(n.squeeze(), pp, size=size)
                 for pp in p
             ])
-        elif (n.ndim > 0) and (p.ndim == 1):
+            randnum = np.moveaxis(randnum, 1, 0)
+        elif (n_dim > 0) and (p.ndim == 1):
             p = p / p.sum()
             randnum = np.asarray([
                 np.random.multinomial(nn, p.squeeze(), size=size)
                 for nn in n
             ])
+            randnum = np.moveaxis(randnum, 1, 0)
         else:
             p = p / p.sum(axis=1, keepdims=True)
             randnum = np.asarray([
                 np.random.multinomial(nn, pp, size=size)
                 for (nn, pp) in zip(n, p)
             ])
+            randnum = np.moveaxis(randnum, 1, 0)
         return randnum.astype(original_dtype)
 
     def random(self, point=None, size=None):
@@ -1259,15 +1287,31 @@ def _setup_matrices(self, colcov, colchol, coltau, rowcov, rowchol, rowtau):
         self.colchol_cov = tt.as_tensor_variable(colchol)
 
     def random(self, point=None, size=None):
-        if size is None:
-            size = list(self.shape)
-
         mu, colchol, rowchol = draw_values(
             [self.mu, self.colchol_cov, self.rowchol_cov],
             point=point,
             size=size)
-        standard_normal = np.random.standard_normal(size)
-        return mu + np.matmul(rowchol, np.matmul(standard_normal, colchol.T))
+        if size is None:
+            size = ()
+        if size in (None, ()):
+            standard_normal = np.random.standard_normal((self.shape[0], colchol.shape[-1]))
+            samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol.T))
+        else:
+            samples = []
+            size = tuple(np.atleast_1d(size))
+            if mu.shape == tuple(self.shape):
+                for _ in range(np.prod(size)):
+                    standard_normal = np.random.standard_normal((self.shape[0], colchol.shape[-1]))
+                    samples.append(mu + np.matmul(rowchol, np.matmul(standard_normal, colchol.T)))
+            else:
+                for j in range(np.prod(size)):
+                    standard_normal = np.random.standard_normal((self.shape[0], colchol[j].shape[-1]))
+                    samples.append(mu[j] +
+                                   np.matmul(rowchol[j], np.matmul(standard_normal, colchol[j].T)))
+            samples = np.array(samples).reshape(size + tuple(self.shape))
+        return samples
+
 
     def _trquaddist(self, value):
         """Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and
@@ -1469,7 +1513,6 @@ def _setup_random(self):
         elif self._cov_type == 'evd':
             covs = []
             for eig, Q in zip(self.eigs_sep, self.Qs):
-                # print()
                 cov_i = tt.dot(Q, tt.dot(tt.diag(eig), Q.T))
                 covs.append(cov_i)
             cov = kronecker(*covs)
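Why Multinomial._random gained the np.moveaxis calls above: stacking per-row draws puts the parameter batch axis first, while the caller expects the size axis first. The same shape fix in isolation, with illustrative values:

import numpy as np

p = np.array([[0.5, 0.5], [0.1, 0.9]])   # two parameter rows
randnum = np.asarray([np.random.multinomial(10, pp, size=3) for pp in p])
assert randnum.shape == (2, 3, 2)         # (batch, size, k)
randnum = np.moveaxis(randnum, 1, 0)
assert randnum.shape == (3, 2, 2)         # (size, batch, k)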
diff --git a/pymc3/model.py b/pymc3/model.py
index 92d4177384..e37ae25ee9 100644
--- a/pymc3/model.py
+++ b/pymc3/model.py
@@ -100,7 +100,7 @@ def get_named_nodes_and_relations(graph):
         is a theano named node, and the corresponding value is the set of
         theano named nodes that are children of the node. These child
         relations skip unnamed intermediate nodes.
-    
+
     """
     if graph.name is not None:
         node_parents = {graph: set()}
@@ -1017,7 +1017,7 @@ def check_test_point(self, test_point=None, round_vals=2):
         if test_point is None:
             test_point = self.test_point
 
-        return Series({RV.name:np.round(RV.logp(self.test_point), round_vals) for RV in self.basic_RVs}, 
+        return Series({RV.name:np.round(RV.logp(self.test_point), round_vals) for RV in self.basic_RVs},
             name='Log-probability of test_point')
 
     def _repr_latex_(self, name=None, dist=None):
@@ -1244,7 +1244,7 @@ def pandas_to_array(data):
         ret = generator(data)
     else:
         ret = np.asarray(data)
-    return pm.smartfloatX(ret)
+    return pm.floatX(ret)
 
 
 def as_tensor(data, name, model, distribution):
@@ -1457,6 +1457,8 @@ def __init__(self, type=None, owner=None, index=None, name=None,
         if distribution is not None:
            self.model = model
            self.distribution = distribution
+           self.dshape = tuple(distribution.shape)
+           self.dsize = int(np.prod(distribution.shape))
 
            transformed_name = get_transformed_name(name, transform)
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index 29faaf8269..b3b569deea 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -11,11 +11,12 @@
 
 from .backends.base import BaseTrace, MultiTrace
 from .backends.ndarray import NDArray
+from .distributions.distribution import draw_values
 from .model import modelcontext, Point, all_continuous
 from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
                            BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
                            Slice, CompoundStep, arraystep)
-from .util import update_start_vals, get_untransformed_name, is_transformed_name
+from .util import update_start_vals, get_untransformed_name, is_transformed_name, get_default_varnames
 from .vartypes import discrete_types
 from pymc3.step_methods.hmc import quadpotential
 from pymc3 import plots
@@ -25,7 +26,7 @@
 import sys
 sys.setrecursionlimit(10000)
 
-__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_ppc_w', 'init_nuts']
+__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_ppc_w', 'init_nuts', 'sample_prior_predictive']
 
 STEP_METHODS = (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
                 BinaryGibbsMetropolis, Slice, CategoricalGibbsMetropolis)
@@ -1276,6 +1277,49 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None,
     return {k: np.asarray(v) for k, v in ppc.items()}
 
 
+def sample_prior_predictive(samples=500, model=None, vars=None, random_seed=None):
+    """Generate samples from the prior predictive distribution.
+
+    Parameters
+    ----------
+    samples : int
+        Number of samples from the prior predictive to generate. Defaults to 500.
+    model : Model (optional if in `with` context)
+    vars : iterable
+        Variables for which to compute the prior predictive samples.
+        Defaults to `model.named_vars`.
+    random_seed : int
+        Seed for the random number generator.
+
+    Returns
+    -------
+    dict
+        Dictionary with the variables as keys. The values are arrays of prior samples.
+    """
+    model = modelcontext(model)
+
+    if vars is None:
+        vars = set(model.named_vars.keys())
+
+    if random_seed is not None:
+        np.random.seed(random_seed)
+
+    names = get_default_varnames(model.named_vars, include_transformed=False)
+    # draw_values fails with auto-transformed variables. transform them later!
+    values = draw_values([model[name] for name in names], size=samples)
+
+    data = {k: v for k, v in zip(names, values)}
+
+    prior = {}
+    for var_name in vars:
+        if var_name in data:
+            prior[var_name] = data[var_name]
+        elif is_transformed_name(var_name):
+            untransformed = get_untransformed_name(var_name)
+            if untransformed in data:
+                prior[var_name] = model[untransformed].transformation.forward_val(data[untransformed])
+    return prior
+
+
 def init_nuts(init='auto', chains=1, n_init=500000, model=None,
               random_seed=None, progressbar=True, **kwargs):
     """Set up the mass matrix initialization for NUTS.
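A sketch of the transformed-variable path in sample_prior_predictive: draws happen on the untransformed scale and are then mapped through the transform's forward_val, so both names show up in the result (the model below is hypothetical):

import pymc3 as pm

with pm.Model():
    sd = pm.HalfNormal('sd', sd=1.)
    prior = pm.sample_prior_predictive(100)

assert prior['sd'].shape == (100,)
assert 'sd_log__' in prior   # forward-transformed draws of 'sd'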
diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py
index 4cb3423dd2..2e72487257 100644
--- a/pymc3/tests/test_distributions.py
+++ b/pymc3/tests/test_distributions.py
@@ -914,12 +914,12 @@ def test_multinomial_mode(self, p, n):
         # [[[.25, .25, .25, .25]], (2, 4), [7, 11]],
         [[[.25, .25, .25, .25],
           [.25, .25, .25, .25]], (2, 4), 13],
-        [[[.25, .25, .25, .25],
-          [.25, .25, .25, .25]], (2, 4), [17, 19]],
         [[[.25, .25, .25, .25],
           [.25, .25, .25, .25]], (1, 2, 4), [23, 29]],
         [[[.25, .25, .25, .25],
           [.25, .25, .25, .25]], (10, 2, 4), [31, 37]],
+        [[[.25, .25, .25, .25],
+          [.25, .25, .25, .25]], (2, 4), [17, 19]],
     ])
     def test_multinomial_random(self, p, shape, n):
         p = np.asarray(p)
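The behavioral contract of the reordered draw_values, mirroring the tests added below: a Deterministic may be requested before the variable it depends on (illustrative model):

import numpy as np
import pymc3 as pm

with pm.Model():
    x = pm.Normal('x', mu=0., sd=1.)
    y = pm.Deterministic('y', 2 * x)

y_val, x_val = pm.distributions.draw_values([y, x])
np.testing.assert_almost_equal(y_val, 2 * x_val)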
diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py
index fde199b6b0..9b847650b7 100644
--- a/pymc3/tests/test_distributions_random.py
+++ b/pymc3/tests/test_distributions_random.py
@@ -78,6 +78,23 @@ def test_draw_scalar_parameters(self):
         npt.assert_almost_equal(mu, 0)
         npt.assert_almost_equal(tau, 1)
 
+    def test_draw_dependencies(self):
+        with pm.Model():
+            x = pm.Normal('x', mu=0., sd=1.)
+            exp_x = pm.Deterministic('exp_x', pm.math.exp(x))
+
+        x, exp_x = pm.distributions.draw_values([x, exp_x])
+        npt.assert_almost_equal(np.exp(x), exp_x)
+
+    def test_draw_order(self):
+        with pm.Model():
+            x = pm.Normal('x', mu=0., sd=1.)
+            exp_x = pm.Deterministic('exp_x', pm.math.exp(x))
+
+        # Need to draw x before drawing exp_x
+        exp_x, x = pm.distributions.draw_values([exp_x, x])
+        npt.assert_almost_equal(np.exp(x), exp_x)
+
     def test_draw_point_replacement(self):
         with pm.Model():
             mu = pm.Normal('mu', mu=0., tau=1e-3)
@@ -187,12 +204,13 @@ def test_different_shapes_and_sample_sizes(self):
                     s = list(size)
                 except TypeError:
                     s = [size]
-                s.extend(shape)
+                if s == [1]:
+                    s = []
+                if shape not in ((), (1,)):
+                    s.extend(shape)
                 e = tuple(s)
                 a = self.sample_random_variable(rv, size).shape
-                expected.append(e)
-                actual.append(a)
-        assert expected == actual
+                assert e == a
 
 
 class TestNormal(BaseTestCases.BaseTestCase):
diff --git a/pymc3/tests/test_model_helpers.py b/pymc3/tests/test_model_helpers.py
index d8df1d8e4b..2b191bd144 100644
--- a/pymc3/tests/test_model_helpers.py
+++ b/pymc3/tests/test_model_helpers.py
@@ -68,17 +68,18 @@ def test_pandas_to_array(self):
         # Check function behavior with Theano graph variable
         theano_output = func(theano_graph_input)
         assert isinstance(theano_output, theano.gof.graph.Variable)
-        assert theano_output.name == input_name
+        assert theano_output.owner.inputs[0].name == input_name
 
         # Check function behavior with generator data
         generator_output = func(square_generator)
+
+        # Output is wrapped with `pm.floatX`, and this unwraps it
+        wrapped = generator_output.owner.inputs[0]
 
         # Make sure the returned object has .set_gen and .set_default methods
-        assert hasattr(generator_output, "set_gen")
-        assert hasattr(generator_output, "set_default")
+        assert hasattr(wrapped, "set_gen")
+        assert hasattr(wrapped, "set_default")
 
         # Make sure the returned object is a Theano TensorVariable
-        assert isinstance(generator_output, tt.TensorVariable)
-
-        return None
+        assert isinstance(wrapped, tt.TensorVariable)
 
     def test_as_tensor(self):
         """
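Why the test above unwraps one level: pm.floatX can put a cast Op on top of the named input, so the original variable sits at owner.inputs[0]. The same effect with a plain theano cast:

import theano.tensor as tt

x = tt.dmatrix('x')           # float64 input
y = tt.cast(x, 'float32')     # roughly what pm.floatX does when a cast is needed
assert y.owner.inputs[0].name == 'x'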
diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py
index a781d88402..0ad6662102 100644
--- a/pymc3/tests/test_sampling.py
+++ b/pymc3/tests/test_sampling.py
@@ -6,6 +6,7 @@
 except ImportError:
     import mock
 
+import numpy.testing as npt
 import pymc3 as pm
 import theano.tensor as tt
 from theano import shared
@@ -337,3 +338,66 @@ def test_exec_nuts_init(method):
     assert len(start) == 2
     assert isinstance(start[0], dict)
     assert 'a' in start[0] and 'b_log__' in start[0]
+
+class TestSampleGenerative(SeededTest):
+    def test_ignores_observed(self):
+        observed = np.random.normal(10, 1, size=200)
+        with pm.Model():
+            # Use a prior that's way off to show we're ignoring the observed variables
+            mu = pm.Normal('mu', mu=-100, sd=1)
+            positive_mu = pm.Deterministic('positive_mu', np.abs(mu))
+            z = -1 - positive_mu
+            pm.Normal('x_obs', mu=z, sd=1, observed=observed)
+            prior = pm.sample_prior_predictive()
+
+        assert (prior['mu'] < 90).all()
+        assert (prior['positive_mu'] > 90).all()
+        assert (prior['x_obs'] < 90).all()
+        npt.assert_array_almost_equal(prior['positive_mu'], np.abs(prior['mu']), decimal=4)
+
+    def test_respects_shape(self):
+        for shape in (2, (2,), (10, 2), (10, 10)):
+            with pm.Model():
+                mu = pm.Gamma('mu', 3, 1, shape=1)
+                goals = pm.Poisson('goals', mu, shape=shape)
+                trace = pm.sample_prior_predictive(10)
+            if shape == 2:  # want to test shape as an int
+                shape = (2,)
+            assert trace['goals'].shape == (10,) + shape
+
+    def test_multivariate(self):
+        with pm.Model():
+            m = pm.Multinomial('m', n=5, p=np.array([0.25, 0.25, 0.25, 0.25]), shape=4)
+            trace = pm.sample_prior_predictive(10)
+
+        assert m.random(size=10).shape == (10, 4)
+        assert trace['m'].shape == (10, 4)
+
+    def test_layers(self):
+        with pm.Model() as model:
+            a = pm.Uniform('a', lower=0, upper=1, shape=10)
+            b = pm.Binomial('b', n=1, p=a, shape=10)
+
+        avg = b.random(size=10000).mean(axis=0)
+        npt.assert_array_almost_equal(avg, 0.5 * np.ones_like(b), decimal=2)
+
+    def test_transformed(self):
+        n = 18
+        at_bats = 45 * np.ones(n, dtype=int)
+        hits = np.random.randint(1, 40, size=n, dtype=int)
+        draws = 50
+
+        with pm.Model() as model:
+            phi = pm.Beta('phi', alpha=1., beta=1.)
+
+            kappa_log = pm.Exponential('logkappa', lam=5.)
+            kappa = pm.Deterministic('kappa', tt.exp(kappa_log))
+
+            thetas = pm.Beta('thetas', alpha=phi*kappa, beta=(1.0-phi)*kappa, shape=n)
+
+            y = pm.Binomial('y', n=at_bats, p=thetas, shape=n, observed=hits)
+            gen = pm.sample_prior_predictive(draws)
+
+        assert gen['phi'].shape == (draws,)
+        assert gen['y'].shape == (draws, n)
+        assert 'thetas_logodds__' in gen
\ No newline at end of file
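A sketch of restricting the output with the `vars` argument (illustrative model; per the docstring above, `vars` takes variable names):

import pymc3 as pm

with pm.Model():
    mu = pm.Normal('mu', mu=0., sd=1.)
    sd = pm.HalfNormal('sd', sd=1.)
    pm.Normal('obs', mu=mu, sd=sd, observed=[0., 1.])
    prior = pm.sample_prior_predictive(200, vars=['mu'])

assert set(prior) == {'mu'}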