diff --git a/pymc3/tests/test_variational_inference.py b/pymc3/tests/test_variational_inference.py
index 5c3c9c8cb9..3ce98ffc14 100644
--- a/pymc3/tests/test_variational_inference.py
+++ b/pymc3/tests/test_variational_inference.py
@@ -207,7 +207,7 @@ def cb(*_):
                 data_t.set_value(next(minibatches))
             mu_ = Normal('mu', mu=mu0, sd=sd0, testval=0)
             Normal('x', mu=mu_, sd=sd, observed=data_t, total_size=n)
-            inf = self.inference()
+            inf = self.inference(scale_cost_to_minibatch=True)
             approx = inf.fit(self.NITER * 3, callbacks=
                              [cb, pm.callbacks.CheckParametersConvergence()],
                              obj_n_mc=10, obj_optimizer=self.optimizer)
diff --git a/pymc3/theanof.py b/pymc3/theanof.py
index 84f60c52c3..5e48adfe5a 100644
--- a/pymc3/theanof.py
+++ b/pymc3/theanof.py
@@ -23,7 +23,6 @@
            'join_nonshared_inputs',
            'make_shared_replacements',
            'generator',
-           'GradScale',
            'set_tt_rng',
            'tt_rng']
 
@@ -417,13 +416,5 @@ def set_tt_rng(new_rng):
     launch_rng(_tt_rng)
 
 
-class GradScale(theano.compile.ViewOp):
-    def __init__(self, multiplier):
-        self.multiplier = multiplier
-
-    def grad(self, args, g_outs):
-        return [self.multiplier * g_out for g_out in g_outs]
-
-
 def floatX_array(x):
     return floatX(np.array(x))
diff --git a/pymc3/util.py b/pymc3/util.py
index f9edc3b08f..7e3957d37b 100644
--- a/pymc3/util.py
+++ b/pymc3/util.py
@@ -6,11 +6,13 @@ def get_transformed_name(name, transform):
     ----------
     name : str
         Name to transform
-    transform : object
+    transform : transforms.Transform
         Should be a subclass of `transforms.Transform`
 
-    Returns:
-    A string to use for the transformed variable
+    Returns
+    -------
+    str
+        A string to use for the transformed variable
     """
     return "{}_{}__".format(name, transform.name)
 
@@ -24,8 +26,10 @@ def is_transformed_name(name):
     ----------
     name : str
         Name to check
 
-    Returns:
-    Boolean, whether the string could have been produced by `get_transormed_name`
+    Returns
+    -------
+    bool
+        Boolean, whether the string could have been produced by `get_transormed_name`
     """
     return name.endswith('__') and name.count('_') >= 3
 
@@ -39,8 +43,10 @@ def get_untransformed_name(name):
     ----------
     name : str
         Name to untransform
-    Returns:
-    String with untransformed version of the name.
+    Returns
+    -------
+    str
+        String with untransformed version of the name.
     """
     if not is_transformed_name(name):
         raise ValueError(u'{} does not appear to be a transformed name'.format(name))
@@ -57,8 +63,10 @@ def get_default_varnames(var_iterator, include_transformed):
     include_transformed : boolean
         Should transformed variable names be included in return value
 
-    Returns:
-    List of variables, possibly filtered
+    Returns
+    -------
+    list
+        List of variables, possibly filtered
     """
     if include_transformed:
         return list(var_iterator)
diff --git a/pymc3/variational/approximations.py b/pymc3/variational/approximations.py
index 6956ce90f6..4345fbdf6d 100644
--- a/pymc3/variational/approximations.py
+++ b/pymc3/variational/approximations.py
@@ -28,25 +28,23 @@ class MeanField(Approximation):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
     model : PyMC3 model for inference
-
     start : Point
         initial mean
-
     cost_part_grad_scale : float or scalar tensor
         Scaling score part of gradient can be useful near optimum
         for archiving better convergence properties. Common schedule
         is 1 at the start and 0 in the end. So slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
         Yuhuai Wu, David Duvenaud, 2016) for details
-
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
 
     References
-    ----------
+    ----------
     Geoffrey Roeder, Yuhuai Wu, David Duvenaud, 2016
     Sticking the Landing: A Simple Reduced-Variance Gradient for ADVI
     approximateinference.org/accepted/RoederEtAl2016.pdf
@@ -109,19 +107,17 @@ class FullRank(Approximation):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
     model : PyMC3 model for inference
-
     start : Point
         initial mean
-
     cost_part_grad_scale : float or scalar tensor
         Scaling score part of gradient can be useful near optimum
         for archiving better convergence properties. Common schedule
        is 1 at the start and 0 in the end. So slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
-        Yuhuai Wu, David Duvenaud, 2016) for details 
-
+        Yuhuai Wu, David Duvenaud, 2016) for details
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
@@ -133,12 +129,13 @@ class FullRank(Approximation):
     approximateinference.org/accepted/RoederEtAl2016.pdf
     """
     def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
+                 scale_cost_to_minibatch=False,
                  gpu_compat=False, seed=None, **kwargs):
         super(FullRank, self).__init__(
             local_rv=local_rv, model=model, cost_part_grad_scale=cost_part_grad_scale,
-            seed=seed,
-            **kwargs
+            scale_cost_to_minibatch=scale_cost_to_minibatch,
+            seed=seed, **kwargs
         )
         self.gpu_compat = gpu_compat
 
@@ -213,7 +210,7 @@ def from_mean_field(cls, mean_field, gpu_compat=False):
         """Construct FullRank from MeanField approximation
 
         Parameters
-        ----------
+        ----------
         mean_field : MeanField
             approximation to start with
 
@@ -256,9 +253,9 @@ class Empirical(Approximation):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     model : PyMC3 model
-
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
@@ -270,11 +267,12 @@ class Empirical(Approximation):
     ...     trace = sample(1000, step=step)
     ...     histogram = Empirical(trace[100:])
     """
-    def __init__(self, trace, local_rv=None, model=None, seed=None, **kwargs):
+    def __init__(self, trace, local_rv=None,
+                 scale_cost_to_minibatch=False,
+                 model=None, seed=None, **kwargs):
         super(Empirical, self).__init__(
-            local_rv=local_rv, model=model, trace=trace, seed=seed,
-            **kwargs
-        )
+            local_rv=local_rv, scale_cost_to_minibatch=scale_cost_to_minibatch,
+            model=model, trace=trace, seed=seed, **kwargs)
 
     def check_model(self, model, **kwargs):
         trace = kwargs.get('trace')
@@ -352,7 +350,8 @@ def cov(self):
         return x.T.dot(x) / self.histogram.shape[0]
 
     @classmethod
-    def from_noise(cls, size, jitter=.01, local_rv=None, start=None, model=None, seed=None):
+    def from_noise(cls, size, jitter=.01, local_rv=None,
+                   start=None, model=None, seed=None, **kwargs):
         """Initialize Histogram with random noise
 
         Parameters
@@ -366,12 +365,16 @@ def from_noise(cls, size, jitter=.01, local_rv=None, start=None, model=None, see
         start : initial point
         model : pm.Model
             PyMC3 Model
+        seed : None or int
+            leave None to use package global RandomStream or other
+            valid value to create instance specific one
+        kwargs : other kwargs passed to init
 
         Returns
-        -------
+        -------
         Empirical
         """
-        hist = cls(None, local_rv=local_rv, model=model, seed=seed)
+        hist = cls(None, local_rv=local_rv, model=model, seed=seed, **kwargs)
         if start is None:
             start = hist.model.test_point
         else:
@@ -390,7 +393,7 @@ def sample_approx(approx, draws=100, include_transformed=True):
     """Draw samples from variational posterior.
 
     Parameters
-    ----------
+    ----------
     approx : Approximation
     draws : int
         Number of random samples.
@@ -398,7 +401,7 @@ def sample_approx(approx, draws=100, include_transformed=True):
         If True, transformed variables are also sampled. Default is True.
 
     Returns
-    -------
+    -------
     trace : pymc3.backends.base.MultiTrace
         Samples drawn from variational posterior.
     """
diff --git a/pymc3/variational/inference.py b/pymc3/variational/inference.py
index 3c51c272f5..d716b84597 100644
--- a/pymc3/variational/inference.py
+++ b/pymc3/variational/inference.py
@@ -304,6 +304,8 @@ class ADVI(Inference):
         1 at the start and 0 in the end. So slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
         Yuhuai Wu, David Duvenaud, 2016) for details
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
@@ -323,11 +325,15 @@ class ADVI(Inference):
 
     -   Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. stat, 1050, 1.
     """
-    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
+    def __init__(self, local_rv=None, model=None,
+                 cost_part_grad_scale=1,
+                 scale_cost_to_minibatch=False,
                  seed=None, start=None):
         super(ADVI, self).__init__(
             KL, MeanField, None,
-            local_rv=local_rv, model=model, cost_part_grad_scale=cost_part_grad_scale,
+            local_rv=local_rv, model=model,
+            cost_part_grad_scale=cost_part_grad_scale,
+            scale_cost_to_minibatch=scale_cost_to_minibatch,
             seed=seed, start=start)
 
     @classmethod
@@ -372,7 +378,8 @@ class FullRankADVI(Inference):
         1 at the start and 0 in the end. So slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
         Yuhuai Wu, David Duvenaud, 2016) for details
-
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
@@ -392,11 +399,15 @@ class FullRankADVI(Inference):
 
     -   Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. stat, 1050, 1.
     """
-    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
+    def __init__(self, local_rv=None, model=None,
+                 cost_part_grad_scale=1,
+                 scale_cost_to_minibatch=False,
                  gpu_compat=False, seed=None, start=None):
         super(FullRankADVI, self).__init__(
             KL, FullRank, None,
-            local_rv=local_rv, model=model, cost_part_grad_scale=cost_part_grad_scale,
+            local_rv=local_rv, model=model,
+            cost_part_grad_scale=cost_part_grad_scale,
+            scale_cost_to_minibatch=scale_cost_to_minibatch,
             gpu_compat=gpu_compat, seed=seed, start=start)
 
     @classmethod
@@ -497,6 +508,8 @@ class SVGD(Inference):
     model : pm.Model
     kernel : callable
         kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.))
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     start : dict
         initial point for inference
     histogram : Empirical
@@ -514,10 +527,13 @@ class SVGD(Inference):
         arXiv:1608.04471
     """
    def __init__(self, n_particles=100, jitter=.01, model=None, kernel=test_functions.rbf,
-                 start=None, histogram=None, seed=None, local_rv=None):
+                 scale_cost_to_minibatch=False, start=None, histogram=None,
+                 seed=None, local_rv=None):
        if histogram is None:
            histogram = Empirical.from_noise(
-                n_particles, jitter=jitter, start=start, model=model, local_rv=local_rv, seed=seed)
+                n_particles, jitter=jitter,
+                scale_cost_to_minibatch=scale_cost_to_minibatch,
+                start=start, model=model, local_rv=local_rv, seed=seed)
        super(SVGD, self).__init__(
            KSD, histogram, kernel,
diff --git a/pymc3/variational/opvi.py b/pymc3/variational/opvi.py
index f43c812bde..f734c4d8d7 100644
--- a/pymc3/variational/opvi.py
+++ b/pymc3/variational/opvi.py
@@ -41,7 +41,7 @@
 from ..distributions.dist_math import rho2sd, log_normal
 from ..model import modelcontext, ArrayOrdering, DictToArrayBijection
 from ..util import get_default_varnames
-from ..theanof import tt_rng, memoize, change_flags, GradScale, identity
+from ..theanof import tt_rng, memoize, change_flags, identity
 
 
 __all__ = [
@@ -456,8 +456,9 @@ class Approximation(object):
         archiving better convergence properties. Common schedule is
         1 at the start and 0 in the end. So slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
-        Yuhuai Wu, David Duvenaud, 2016) for details 
-
+        Yuhuai Wu, David Duvenaud, 2016) for details
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
@@ -512,8 +513,15 @@ class Approximation(object):
     initial_dist_name = 'normal'
     initial_dist_map = 0.
 
-    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1, seed=None, **kwargs):
+    def __init__(self, local_rv=None, model=None,
+                 cost_part_grad_scale=1,
+                 scale_cost_to_minibatch=False,
+                 seed=None, **kwargs):
         model = modelcontext(model)
+        self.scale_cost_to_minibatch = theano.shared(np.int8(0))
+        if scale_cost_to_minibatch:
+            self.scale_cost_to_minibatch.set_value(1)
+        self.cost_part_grad_scale = pm.floatX(cost_part_grad_scale)
         self._seed = seed
         self._rng = tt_rng(seed)
         self.model = model
@@ -536,7 +544,6 @@ def get_transformed(v):
         self.flat_view = model.flatten(
             vars=self.local_vars + self.global_vars
         )
-        self.grad_scale_op = GradScale(cost_part_grad_scale)
         self._setup(**kwargs)
         self.shared_params = self.create_shared_params(**kwargs)
 
@@ -555,6 +562,8 @@ def seed(self, seed=None):
     def normalizing_constant(self):
         t = self.to_flat_input(tt.max([v.scaling for v in self.model.basic_RVs]))
         t = theano.clone(t, {self.input: tt.zeros(self.total_size)})
+        # if not scale_cost_to_minibatch: t=1
+        t = tt.switch(self.scale_cost_to_minibatch, t, tt.constant(1, dtype=t.dtype))
         return t
 
     def _setup(self, **kwargs):
@@ -690,7 +699,7 @@ def scale_grad(self, inp):
         Sticking the Landing: A Simple Reduced-Variance Gradient for ADVI
         approximateinference.org/accepted/RoederEtAl2016.pdf
         """
-        return self.grad_scale_op(inp)
+        return theano.gradient.grad_scale(inp, self.cost_part_grad_scale)
 
     def to_flat_input(self, node):
         """
@@ -868,7 +877,7 @@ def points():
     def log_q_W_local(self, z):
         """log_q_W samples over q for local vars
         Gradient wrt mu, rho in density parametrization
-        is set to zero to lower variance of ELBO
+        can be scaled to lower variance of ELBO
         """
         if not self.local_vars:
             return tt.constant(0)
@@ -878,13 +887,9 @@ def log_q_W_local(self, z):
         logp = log_normal(z[self.local_slc], mu, rho=rho)
         scaling = []
         for var in self.local_vars:
-            scaling.append(tt.ones(var.dsize)*var.scaling)
+            scaling.append(tt.repeat(var.scaling, var.dsize))
         scaling = tt.concatenate(scaling)
-        if z.ndim > 1:  # pragma: no cover
-            # rare case when logq(z) is called directly
-            logp *= scaling[None]
-        else:
-            logp *= scaling
+        logp *= scaling
         return self.to_flat_input(tt.sum(logp))
 
     def log_q_W_global(self, z):  # pragma: no cover
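A minimal sketch of how the new `scale_cost_to_minibatch` keyword is meant to be used, mirroring the updated test at the top of this diff; the model, data, and number of fit iterations below are illustrative assumptions, not part of the change.

    # Hypothetical usage of the scale_cost_to_minibatch flag added in this diff;
    # the data, model, and fit length are made up for illustration.
    import numpy as np
    import theano
    import pymc3 as pm

    n = 1000
    data = np.random.randn(n).astype(theano.config.floatX)
    batch = theano.shared(data[:100])  # minibatch placeholder, swapped via set_value

    with pm.Model():
        mu = pm.Normal('mu', mu=0., sd=10., testval=0.)
        pm.Normal('x', mu=mu, sd=1., observed=batch, total_size=n)
        # With scale_cost_to_minibatch=True the objective is scaled to the
        # minibatch rather than to the full dataset of size total_size.
        inference = pm.ADVI(scale_cost_to_minibatch=True)
        approx = inference.fit(10000)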