Skip to content

Commit 6cefd17

Browse files
committed
Fix for pymc-devs#3210, which uses a completely different approach from PR pymc-devs#3214. It introduces a context manager inside draw_values that makes all the values drawn from TensorVariables or MultiObservedRVs available to calls nested within the original draw_values call. It is partly inspired by how Edward2 approaches the problem of forward sampling: Ed2 tensors store a _values attribute the first time sample is called, and afterwards simply return it. They can do this because of their functional scheme, where the entire graph is recreated each time the generative function is called. Our object-oriented paradigm cannot set a fixed _values; instead, each call has to know that it is in the context of a single draw_values invocation. That is why I opted for context managers to store the drawn values.
1 parent cf62eec commit 6cefd17

File tree

4 files changed

+482
-297
lines changed

4 files changed

+482
-297
lines changed

pymc3/distributions/distribution.py

Lines changed: 214 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import collections
1+
import six
22
import numbers
33

44
import numpy as np
@@ -8,7 +8,7 @@
88
from ..memoize import memoize
99
from ..model import (
1010
Model, get_named_nodes_and_relations, FreeRV,
11-
ObservedRV, MultiObservedRV
11+
ObservedRV, MultiObservedRV, Context, InitContextMeta
1212
)
1313
from ..vartypes import string_types
1414

@@ -214,6 +214,48 @@ def random(self, *args, **kwargs):
214214
"Define a custom random method and pass it as kwarg random")
215215

216216

217+
class _DrawValuesContext(six.with_metaclass(InitContextMeta, Context)):
218+
""" A context manager class used while drawing values with draw_values
219+
"""
220+
221+
def __new__(cls, *args, **kwargs):
222+
# resolves the parent instance
223+
instance = super(_DrawValuesContext, cls).__new__(cls)
224+
if cls.get_contexts():
225+
potencial_parent = cls.get_contexts()[-1]
226+
# We have to make sure that the context is a _DrawValuesContext
227+
# and not a Model
228+
if isinstance(potencial_parent, cls):
229+
instance._parent = potencial_parent
230+
else:
231+
instance._parent = None
232+
else:
233+
instance._parent = None
234+
return instance
235+
236+
def __init__(self):
237+
if self.parent is not None:
238+
# All _DrawValuesContext instances that are in the context of
239+
# another _DrawValuesContext will share the reference to the
240+
# drawn_vars dictionary. This means that separate branches
241+
# in the nested _DrawValuesContext context tree will see the
242+
# same drawn values
243+
self.drawn_vars = self.parent.drawn_vars
244+
else:
245+
self.drawn_vars = dict()
246+
247+
@property
248+
def parent(self):
249+
return self._parent
250+
251+
252+
def is_fast_drawable(var):
253+
return isinstance(var, (numbers.Number,
254+
np.ndarray,
255+
tt.TensorConstant,
256+
tt.sharedvar.SharedVariable))
257+
258+
217259
def draw_values(params, point=None, size=None):
218260
"""
219261
Draw (fix) parameter values. Handles a number of cases:
@@ -232,97 +274,134 @@ def draw_values(params, point=None, size=None):
232274
b) are *RVs with a random method
233275
234276
"""
235-
# Distribution parameters may be nodes which have named node-inputs
236-
# specified in the point. Need to find the node-inputs, their
237-
# parents and children to replace them.
238-
leaf_nodes = {}
239-
named_nodes_parents = {}
240-
named_nodes_children = {}
241-
for param in params:
242-
if hasattr(param, 'name'):
243-
# Get the named nodes under the `param` node
244-
nn, nnp, nnc = get_named_nodes_and_relations(param)
245-
leaf_nodes.update(nn)
246-
# Update the discovered parental relationships
247-
for k in nnp.keys():
248-
if k not in named_nodes_parents.keys():
249-
named_nodes_parents[k] = nnp[k]
250-
else:
251-
named_nodes_parents[k].update(nnp[k])
252-
# Update the discovered child relationships
253-
for k in nnc.keys():
254-
if k not in named_nodes_children.keys():
255-
named_nodes_children[k] = nnc[k]
256-
else:
257-
named_nodes_children[k].update(nnc[k])
258-
259-
# Init givens and the stack of nodes to try to `_draw_value` from
260-
givens = {}
261-
stored = set() # Some nodes
262-
stack = list(leaf_nodes.values()) # A queue would be more appropriate
263-
while stack:
264-
next_ = stack.pop(0)
265-
if next_ in stored:
266-
# If the node already has a givens value, skip it
267-
continue
268-
elif isinstance(next_, (tt.TensorConstant,
269-
tt.sharedvar.SharedVariable)):
270-
# If the node is a theano.tensor.TensorConstant or a
271-
# theano.tensor.sharedvar.SharedVariable, its value will be
272-
# available automatically in _compile_theano_function so
273-
# we can skip it. Furthermore, if this node was treated as a
274-
# TensorVariable that should be compiled by theano in
275-
# _compile_theano_function, it would raise a `TypeError:
276-
# ('Constants not allowed in param list', ...)` for
277-
# TensorConstant, and a `TypeError: Cannot use a shared
278-
# variable (...) as explicit input` for SharedVariable.
279-
stored.add(next_.name)
280-
continue
281-
else:
282-
# If the node does not have a givens value, try to draw it.
283-
# The named node's children givens values must also be taken
284-
# into account.
285-
children = named_nodes_children[next_]
286-
temp_givens = [givens[k] for k in givens if k in children]
287-
try:
288-
# This may fail for autotransformed RVs, which don't
289-
# have the random method
290-
givens[next_.name] = (next_, _draw_value(next_,
291-
point=point,
292-
givens=temp_givens,
293-
size=size))
294-
stored.add(next_.name)
295-
except theano.gof.fg.MissingInputError:
296-
# The node failed, so we must add the node's parents to
297-
# the stack of nodes to try to draw from. We exclude the
298-
# nodes in the `params` list.
299-
stack.extend([node for node in named_nodes_parents[next_]
300-
if node is not None and
301-
node.name not in stored and
302-
node not in params])
303-
304-
# the below makes sure the graph is evaluated in order
305-
# test_distributions_random::TestDrawValues::test_draw_order fails without it
306-
params = dict(enumerate(params)) # some nodes are not hashable
307-
evaluated = {}
308-
to_eval = set()
309-
missing_inputs = set(params)
310-
while to_eval or missing_inputs:
311-
if to_eval == missing_inputs:
312-
raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval]))
313-
to_eval = set(missing_inputs)
314-
missing_inputs = set()
315-
for param_idx in to_eval:
316-
param = params[param_idx]
317-
if hasattr(param, 'name') and param.name in givens:
318-
evaluated[param_idx] = givens[param.name][1]
277+
# Get fast drawable values (i.e. things in point or numbers, arrays,
278+
# constants or shares, or things that were already drawn in related
279+
# contexts)
280+
if point is None:
281+
point = {}
282+
with _DrawValuesContext() as context:
283+
params = dict(enumerate(params))
284+
drawn = context.drawn_vars
285+
evaluated = {}
286+
symbolic_params = []
287+
for i, p in params.items():
288+
# If the param is fast drawable, then draw the value immediately
289+
if is_fast_drawable(p):
290+
v = _draw_value(p, point=point, size=size)
291+
evaluated[i] = v
292+
continue
293+
294+
name = getattr(p, 'name', None)
295+
if p in drawn:
296+
# param was drawn in related contexts
297+
v = drawn[p]
298+
evaluated[i] = v
299+
elif name is not None and name in point:
300+
# param.name is in point
301+
v = point[name]
302+
evaluated[i] = drawn[p] = v
319303
else:
320-
try: # might evaluate in a bad order,
321-
evaluated[param_idx] = _draw_value(param, point=point, givens=givens.values(), size=size)
322-
if isinstance(param, collections.Hashable) and named_nodes_parents.get(param):
323-
givens[param.name] = (param, evaluated[param_idx])
304+
# param still needs to be drawn
305+
symbolic_params.append((i, p))
306+
307+
if not symbolic_params:
308+
# We only need to enforce the correct order if there are symbolic
309+
# params that could be drawn in variable order
310+
return [evaluated[i] for i in params]
311+
312+
# Distribution parameters may be nodes which have named node-inputs
313+
# specified in the point. Need to find the node-inputs, their
314+
# parents and children to replace them.
315+
leaf_nodes = {}
316+
named_nodes_parents = {}
317+
named_nodes_children = {}
318+
for _, param in symbolic_params:
319+
if hasattr(param, 'name'):
320+
# Get the named nodes under the `param` node
321+
nn, nnp, nnc = get_named_nodes_and_relations(param)
322+
leaf_nodes.update(nn)
323+
# Update the discovered parental relationships
324+
for k in nnp.keys():
325+
if k not in named_nodes_parents.keys():
326+
named_nodes_parents[k] = nnp[k]
327+
else:
328+
named_nodes_parents[k].update(nnp[k])
329+
# Update the discovered child relationships
330+
for k in nnc.keys():
331+
if k not in named_nodes_children.keys():
332+
named_nodes_children[k] = nnc[k]
333+
else:
334+
named_nodes_children[k].update(nnc[k])
335+
336+
# Init givens and the stack of nodes to try to `_draw_value` from
337+
givens = {p.name: (p, v) for p, v in drawn.items()
338+
if getattr(p, 'name', None) is not None}
339+
stack = list(leaf_nodes.values()) # A queue would be more appropriate
340+
while stack:
341+
next_ = stack.pop(0)
342+
if next_ in drawn:
343+
# If the node already has a givens value, skip it
344+
continue
345+
elif isinstance(next_, (tt.TensorConstant,
346+
tt.sharedvar.SharedVariable)):
347+
# If the node is a theano.tensor.TensorConstant or a
348+
# theano.tensor.sharedvar.SharedVariable, its value will be
349+
# available automatically in _compile_theano_function so
350+
# we can skip it. Furthermore, if this node was treated as a
351+
# TensorVariable that should be compiled by theano in
352+
# _compile_theano_function, it would raise a `TypeError:
353+
# ('Constants not allowed in param list', ...)` for
354+
# TensorConstant, and a `TypeError: Cannot use a shared
355+
# variable (...) as explicit input` for SharedVariable.
356+
continue
357+
else:
358+
# If the node does not have a givens value, try to draw it.
359+
# The named node's children givens values must also be taken
360+
# into account.
361+
children = named_nodes_children[next_]
362+
temp_givens = [givens[k] for k in givens if k in children]
363+
try:
364+
# This may fail for autotransformed RVs, which don't
365+
# have the random method
366+
value = _draw_value(next_,
367+
point=point,
368+
givens=temp_givens,
369+
size=size)
370+
givens[next_.name] = (next_, value)
371+
drawn[next_] = value
324372
except theano.gof.fg.MissingInputError:
325-
missing_inputs.add(param_idx)
373+
# The node failed, so we must add the node's parents to
374+
# the stack of nodes to try to draw from. We exclude the
375+
# nodes in the `params` list.
376+
stack.extend([node for node in named_nodes_parents[next_]
377+
if node is not None and
378+
node.name not in drawn and
379+
node not in params])
380+
381+
# the below makes sure the graph is evaluated in order
382+
# test_distributions_random::TestDrawValues::test_draw_order fails without it
383+
# The remaining params that must be drawn are all hashable
384+
to_eval = set()
385+
missing_inputs = set([j for j, p in symbolic_params])
386+
while to_eval or missing_inputs:
387+
if to_eval == missing_inputs:
388+
raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval]))
389+
to_eval = set(missing_inputs)
390+
missing_inputs = set()
391+
for param_idx in to_eval:
392+
param = params[param_idx]
393+
if param in drawn:
394+
evaluated[param_idx] = drawn[param]
395+
else:
396+
try: # might evaluate in a bad order,
397+
value = _draw_value(param,
398+
point=point,
399+
givens=givens.values(),
400+
size=size)
401+
evaluated[param_idx] = drawn[param] = value
402+
givens[param.name] = (param, value)
403+
except theano.gof.fg.MissingInputError:
404+
missing_inputs.add(param_idx)
326405

327406
return [evaluated[j] for j in params] # set the order back
328407

@@ -400,8 +479,16 @@ def _draw_value(param, point=None, givens=None, size=None):
400479
# reset shape to account for shape changes
401480
# with theano.shared inputs
402481
dist_tmp.shape = np.array([])
403-
val = dist_tmp.random(point=point, size=None)
404-
dist_tmp.shape = val.shape
482+
val = np.atleast_1d(dist_tmp.random(point=point,
483+
size=None))
484+
# Sometimes point may change the size of val but not the
485+
# distribution's shape
486+
if point and size is not None:
487+
temp_size = np.atleast_1d(size)
488+
if all(val.shape[:len(temp_size)] == temp_size):
489+
dist_tmp.shape = val.shape[len(temp_size):]
490+
else:
491+
dist_tmp.shape = val.shape
405492
return dist_tmp.random(point=point, size=size)
406493
else:
407494
return param.distribution.random(point=point, size=size)
@@ -411,10 +498,24 @@ def _draw_value(param, point=None, givens=None, size=None):
411498
else:
412499
variables = values = []
413500
func = _compile_theano_function(param, variables)
414-
if size and values and not all(var.dshape == val.shape for var, val in zip(variables, values)):
415-
return np.array([func(*v) for v in zip(*values)])
501+
if size is not None:
502+
size = np.atleast_1d(size)
503+
dshaped_variables = all((hasattr(var, 'dshape')
504+
for var in variables))
505+
if (values and dshaped_variables and
506+
not all(var.dshape == getattr(val, 'shape', tuple())
507+
for var, val in zip(variables, values))):
508+
output = np.array([func(*v) for v in zip(*values)])
509+
elif (size is not None and any((val.ndim > var.ndim)
510+
for var, val in zip(variables, values))):
511+
output = np.array([func(*v) for v in zip(*values)])
416512
else:
417-
return func(*values)
513+
output = func(*values)
514+
return output
515+
print(param,
516+
type(param),
517+
isinstance(param, tt.TensorVariable),
518+
isinstance(param, (tt.TensorVariable, MultiObservedRV)))
418519
raise ValueError('Unexpected type in draw_value: %s' % type(param))
419520

420521

@@ -499,6 +600,20 @@ def generate_samples(generator, *args, **kwargs):
499600
samples = generator(size=broadcast_shape, *args, **kwargs)
500601
elif dist_shape == broadcast_shape:
501602
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
603+
elif len(dist_shape) == 0 and size_tup and broadcast_shape[:len(size_tup)] == size_tup:
604+
# Input's dist_shape is scalar, but it has size repetitions.
605+
# So now the size matches but we have to manually broadcast to
606+
# the right dist_shape
607+
samples = [generator(*args, **kwargs)]
608+
if samples[0].shape == broadcast_shape:
609+
samples = samples[0]
610+
else:
611+
suffix = broadcast_shape[len(size_tup):] + dist_shape
612+
samples.extend([generator(*args, **kwargs).
613+
reshape(broadcast_shape)[..., np.newaxis]
614+
for _ in range(np.prod(suffix,
615+
dtype=int) - 1)])
616+
samples = np.hstack(samples).reshape(size_tup + suffix)
502617
else:
503618
samples = None
504619
# Args have been broadcast correctly, can just ask for the right shape out
@@ -515,9 +630,11 @@ def generate_samples(generator, *args, **kwargs):
515630
if samples is None:
516631
raise TypeError('''Attempted to generate values with incompatible shapes:
517632
size: {size}
633+
size_tup: {size_tup}
634+
broadcast_shape[:len(size_tup)] == size_tup: {test}
518635
dist_shape: {dist_shape}
519636
broadcast_shape: {broadcast_shape}
520-
'''.format(size=size, dist_shape=dist_shape, broadcast_shape=broadcast_shape))
637+
'''.format(size=size, size_tup=size_tup, dist_shape=dist_shape, broadcast_shape=broadcast_shape, test=broadcast_shape[:len(size_tup)] == size_tup))
521638

522639
# reshape samples here
523640
if samples.shape[0] == 1 and size == 1:

pymc3/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,16 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
13861386
self.distribution = distribution
13871387
self.scaling = _get_scaling(total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim)
13881388

1389+
# Make hashable by id for draw_values
1390+
def __hash__(self):
1391+
return id(self)
1392+
1393+
def __eq__(self, other):
1394+
return self.id == other.id
1395+
1396+
def __ne__(self, other):
1397+
return not self == other
1398+
13891399

13901400
def _walk_up_rv(rv):
13911401
"""Walk up theano graph to get inputs for deterministic RV."""

0 commit comments

Comments
 (0)