
Commit 983d444

Add ValueGradFunction to model
1 parent: 8fb942c

6 files changed: 304 additions & 17 deletions

pymc3/blocking.py

Lines changed: 22 additions & 8 deletions
@@ -23,14 +23,28 @@ class ArrayOrdering(object):
 
     def __init__(self, vars):
         self.vmap = []
-        dim = 0
+        self._by_name = {}
+        size = 0
 
         for var in vars:
-            slc = slice(dim, dim + var.dsize)
-            self.vmap.append(VarMap(str(var), slc, var.dshape, var.dtype))
-            dim += var.dsize
+            name = var.name
+            if name is None:
+                raise ValueError('Unnamed variable in ArrayOrdering.')
+            if name in self._by_name:
+                raise ValueError('Name of variable not unique: %s.' % name)
+            if not hasattr(var, 'dshape') or not hasattr(var, 'dsize'):
+                raise ValueError('Shape of variable %s not known.' % name)
+
+            slc = slice(size, size + var.dsize)
+            varmap = VarMap(name, slc, var.dshape, var.dtype)
+            self.vmap.append(varmap)
+            self._by_name[name] = varmap
+            size += var.dsize
+
+        self.size = size
 
-        self.dimensions = dim
+    def __getitem__(self, key):
+        return self._by_name[key]
 
 
 class DictToArrayBijection(object):
@@ -58,7 +72,7 @@ def map(self, dpt):
         ----------
         dpt : dict
         """
-        apt = np.empty(self.ordering.dimensions, dtype=self.array_dtype)
+        apt = np.empty(self.ordering.size, dtype=self.array_dtype)
         for var, slc, _, _ in self.ordering.vmap:
             apt[slc] = dpt[var].ravel()
         return apt
@@ -125,7 +139,7 @@ def __init__(self, list_arrays, intype='numpy'):
             dim += array.size
             count += 1
 
-        self.dimensions = dim
+        self.size = dim
 
 
 class ListToArrayBijection(object):
@@ -158,7 +172,7 @@ def fmap(self, list_arrays):
            single array comprising all the input arrays
        """
 
-        array = np.empty(self.ordering.dimensions)
+        array = np.empty(self.ordering.size)
        for list_ind, slc, _, _, _ in self.ordering.vmap:
            array[slc] = list_arrays[list_ind].ravel()
        return array
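
For reference, a minimal sketch (not part of the commit) of how the reworked ordering behaves. `FakeVar` is a hypothetical stand-in for anything carrying the `name`, `dshape`, `dsize`, and `dtype` attributes the constructor now validates; inside pymc3, free RVs provide these:

    import collections

    from pymc3.blocking import ArrayOrdering

    FakeVar = collections.namedtuple('FakeVar', 'name, dshape, dsize, dtype')

    mu = FakeVar('mu', (), 1, 'float64')
    beta = FakeVar('beta', (3,), 3, 'float64')

    ordering = ArrayOrdering([mu, beta])
    assert ordering.size == 4                    # replaces the old `dimensions`
    assert ordering['beta'].slc == slice(1, 4)   # new name-based lookup
    # ArrayOrdering([mu, mu]) would now raise: name of variable not unique.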

pymc3/model.py

Lines changed: 174 additions & 5 deletions
@@ -277,6 +277,168 @@ def tree_contains(self, item):
         return dict.__contains__(self, item)
 
 
+class ValueGradFunction(object):
+    """Create a theano function that computes a value and its gradient.
+
+    Parameters
+    ----------
+    cost : theano variable
+        The value that we compute with its gradient.
+    grad_vars : list of named theano variables or None
+        The arguments with respect to which the gradient is computed.
+    extra_vars : list of named theano variables or None
+        Other arguments of the function that are assumed constant. They
+        are stored in shared variables and can be set using
+        `set_extra_values`.
+    dtype : str, default=theano.config.floatX
+        The dtype of the arrays.
+    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, default='no'
+        Casting rule for casting `grad_vars` to the array dtype.
+        See `numpy.can_cast` for a description of the options.
+        Keep in mind that we cast the variables to the array *and*
+        back from the array dtype to the variable dtype.
+    kwargs
+        Extra arguments are passed on to `theano.function`.
+
+    Attributes
+    ----------
+    size : int
+        The number of elements in the parameter array.
+    profile : theano profiling object or None
+        The profiling object of the theano function that computes value and
+        gradient. This is None unless `profile=True` was set in the
+        kwargs.
+    """
+    def __init__(self, cost, grad_vars, extra_vars=None, dtype=None,
+                 casting='no', **kwargs):
+        if extra_vars is None:
+            extra_vars = []
+
+        names = [arg.name for arg in grad_vars + extra_vars]
+        if any(name is None for name in names):
+            raise ValueError('Arguments must be named.')
+        if len(set(names)) != len(names):
+            raise ValueError('Names of the arguments are not unique.')
+
+        if cost.ndim > 0:
+            raise ValueError('Cost must be a scalar.')
+
+        self._grad_vars = grad_vars
+        self._extra_vars = extra_vars
+        self._extra_var_names = set(var.name for var in extra_vars)
+        self._cost = cost
+        self._ordering = ArrayOrdering(grad_vars)
+        self.size = self._ordering.size
+        self._extra_are_set = False
+        if dtype is None:
+            dtype = theano.config.floatX
+        self._dtype = dtype
+        for var in self._grad_vars:
+            if not np.can_cast(var.dtype, self._dtype, casting):
+                raise TypeError('Invalid dtype for variable %s. Can not '
+                                'cast to %s with casting rule %s.'
+                                % (var.name, self._dtype, casting))
+
+        givens = []
+        self._extra_vars_shared = {}
+        for var in extra_vars:
+            shared = theano.shared(var.tag.test_value, var.name + '_shared__')
+            self._extra_vars_shared[var.name] = shared
+            givens.append((var, shared))
+
+        self._vars_joined, self._cost_joined = self._build_joined(
+            self._cost, grad_vars, self._ordering.vmap)
+
+        grad = tt.grad(self._cost_joined, self._vars_joined)
+
+        inputs = [self._vars_joined]
+
+        self._theano_function = theano.function(
+            inputs, [self._cost_joined, grad], givens=givens, **kwargs)
+
+    def set_extra_values(self, extra_vars):
+        self._extra_are_set = True
+        for var in self._extra_vars:
+            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
+
+    def get_extra_values(self):
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        return {var.name: self._extra_vars_shared[var.name].get_value()
+                for var in self._extra_vars}
+
+    def __call__(self, array, grad_out=None, extra_vars=None):
+        if extra_vars is not None:
+            self.set_extra_values(extra_vars)
+
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        if array.shape != (self.size,):
+            raise ValueError('Invalid shape for array. Must be %s but is %s.'
+                             % ((self.size,), array.shape))
+
+        if grad_out is None:
+            out = np.empty_like(array)
+        else:
+            out = grad_out
+
+        logp, dlogp = self._theano_function(array)
+        if grad_out is None:
+            return logp, dlogp
+        else:
+            out[...] = dlogp
+            return logp
+
+    @property
+    def profile(self):
+        """Profiling information of the underlying theano function."""
+        return self._theano_function.profile
+
+    def dict_to_array(self, point):
+        """Convert a dictionary with values for grad_vars to an array."""
+        array = np.empty(self.size, dtype=self._dtype)
+        for varmap in self._ordering.vmap:
+            array[varmap.slc] = point[varmap.var].ravel().astype(self._dtype)
+        return array
+
+    def array_to_dict(self, array):
+        """Convert an array to a dictionary containing the grad_vars."""
+        if array.shape != (self.size,):
+            raise ValueError('Array should have shape (%s,) but has %s'
+                             % (self.size, array.shape))
+        if array.dtype != self._dtype:
+            raise ValueError('Array has invalid dtype. Should be %s but is %s'
+                             % (self._dtype, array.dtype))
+        point = {}
+        for varmap in self._ordering.vmap:
+            data = array[varmap.slc].reshape(varmap.shp)
+            point[varmap.var] = data.astype(varmap.dtyp)
+
+        return point
+
+    def array_to_full_dict(self, array):
+        """Convert an array to a dictionary with grad_vars and extra_vars."""
+        point = self.array_to_dict(array)
+        for name, var in self._extra_vars_shared.items():
+            point[name] = var.get_value()
+        return point
+
+    def _build_joined(self, cost, args, vmap):
+        args_joined = tt.vector('__args_joined')
+        dtype = theano.config.floatX
+        args_joined.tag.test_value = np.zeros(self.size, dtype=dtype)
+
+        joined_slices = {}
+        for varmap in vmap:
+            sliced = args_joined[varmap.slc].reshape(varmap.shp)
+            joined_slices[varmap.var] = sliced
+
+        replace = {var: joined_slices[var.name] for var in args}
+        return args_joined, theano.clone(cost, replace=replace)
+
+
 class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
     """Encapsulates the variables and likelihood factors of a model.
 
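To make the interface concrete, here is a minimal standalone sketch (an illustration under stated assumptions, not code from the commit). The manual `dshape`/`dsize` assignments mimic the metadata that pymc3 free RVs already carry; within a model you would normally go through `Model.logp_dlogp_function` instead:

    import numpy as np
    import theano
    import theano.tensor as tt

    from pymc3.model import ValueGradFunction

    floatX = theano.config.floatX

    # `x` is differentiated; `a` is an extra variable kept in a shared container.
    x = tt.vector('x')
    x.tag.test_value = np.zeros(2, dtype=floatX)
    x.dshape, x.dsize = (2,), 2    # assumed: ArrayOrdering needs these attributes
    a = tt.scalar('a')
    a.tag.test_value = np.array(2., dtype=floatX)

    cost = (a * x ** 2).sum()
    func = ValueGradFunction(cost, [x], [a])
    func.set_extra_values({'a': np.array(2., dtype=floatX)})

    value, grad = func(np.array([1., 2.], dtype=floatX))
    # value == 10.0 and grad == [4.0, 8.0], since d/dx a*x**2 = 2*a*x
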
@@ -419,7 +581,6 @@ def bijection(self):
         return bij
 
     @property
-    @memoize
     def dict_to_array(self):
         return self.bijection.map
 
@@ -428,18 +589,27 @@ def ndim(self):
         return sum(var.dsize for var in self.free_RVs)
 
     @property
-    @memoize
     def logp_array(self):
         return self.bijection.mapf(self.fastlogp)
 
     @property
-    @memoize
     def dlogp_array(self):
         vars = inputvars(self.cont_vars)
         return self.bijection.mapf(self.fastdlogp(vars))
 
+    def logp_dlogp_function(self, grad_vars=None, **kwargs):
+        if grad_vars is None:
+            grad_vars = list(typefilter(self.free_RVs, continuous_types))
+        else:
+            for var in grad_vars:
+                if var.dtype not in continuous_types:
+                    raise ValueError("Can only compute the gradient of "
+                                     "continuous types: %s" % var)
+        varnames = [var.name for var in grad_vars]
+        extra_vars = [var for var in self.free_RVs if var.name not in varnames]
+        return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)
+
     @property
-    @memoize
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         with self:
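
A hedged sketch of the intended entry point, using a made-up one-parameter model (names are illustrative only):

    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.Normal('mu', 0., 1.)
        pm.Normal('y', mu, 1., observed=np.array([0.1, -0.3]))

    # Compile value and gradient for all continuous free RVs (here just `mu`).
    func = model.logp_dlogp_function()
    func.set_extra_values({})                  # no non-gradient free RVs remain
    x = func.dict_to_array(model.test_point)   # {'mu': ...} -> flat array
    logp, dlogp = func(x)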
@@ -595,7 +765,6 @@ def __getitem__(self, key):
         except KeyError:
             raise e
 
-    @memoize
     def makefn(self, outs, mode=None, *args, **kwargs):
         """Compiles a Theano function which returns `outs` and takes the variable
         ancestors of `outs` as inputs.

pymc3/step_methods/arraystep.py

Lines changed: 24 additions & 0 deletions
@@ -156,6 +156,30 @@ def step(self, point):
         return bij.rmap(apoint)
 
 
+class GradientSharedStep(BlockedStep):
+    def __init__(self, vars, model=None, blocked=True,
+                 gpu_ctx=None, **theano_kwargs):
+        model = modelcontext(model)
+        self.vars = vars
+        self.blocked = blocked
+
+        self._logp_dlogp_func = model.logp_dlogp_function(
+            vars, **theano_kwargs)
+
+    def step(self, point):
+        self._logp_dlogp_func.set_extra_values(point)
+        array = self._logp_dlogp_func.dict_to_array(point)
+
+        if self.generates_stats:
+            apoint, stats = self.astep(array)
+            point = self._logp_dlogp_func.array_to_dict(apoint)
+            return point, stats
+        else:
+            apoint = self.astep(array)
+            point = self._logp_dlogp_func.array_to_dict(apoint)
+            return point
+
+
 def metrop_select(mr, q, q0):
     """Perform rejection/acceptance step for Metropolis class samplers.
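
Subclasses are expected to supply `astep`, which receives and returns the flat parameter array that `step` builds via `dict_to_array`. A hypothetical illustration of that contract (a toy, not a real sampler from this commit):

    class ToyGradientStep(GradientSharedStep):
        """Nudge the point uphill in logp; illustrates the astep contract."""
        generates_stats = False

        def astep(self, q0):
            # q0 is the flat array; the compiled function returns logp and
            # its gradient in one call.
            logp, dlogp = self._logp_dlogp_func(q0)
            return q0 + 1e-3 * dlogp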

pymc3/step_methods/smc.py

Lines changed: 2 additions & 2 deletions
@@ -324,7 +324,7 @@ def select_end_points(self, mtrace):
         likelihoods : :class:`numpy.ndarray`
             Array of likelihoods of the trace end-points
         """
-        array_population = np.zeros((self.n_chains, self.ordering.dimensions))
+        array_population = np.zeros((self.n_chains, self.ordering.size))
         n_steps = len(mtrace)
 
         # collect end points of each chain and put into array
@@ -357,7 +357,7 @@ def get_chain_previous_lpoint(self, mtrace):
         chain_previous_lpoint : list
             all unobservedRV values, including dataset likelihoods
         """
-        array_population = np.zeros((self.n_chains, self.lordering.dimensions))
+        array_population = np.zeros((self.n_chains, self.lordering.size))
         n_steps = len(mtrace)
         for _, slc, shp, _, var in self.lordering.vmap:
             slc_population = mtrace.get_values(varname=var, burn=n_steps - 1, combine=True)
