1
- import collections
1
+ import six
2
2
import numbers
3
3
4
4
import numpy as np
8
8
from ..memoize import memoize
9
9
from ..model import (
10
10
Model , get_named_nodes_and_relations , FreeRV ,
11
- ObservedRV , MultiObservedRV
11
+ ObservedRV , MultiObservedRV , Context , InitContextMeta
12
12
)
13
13
from ..vartypes import string_types
14
14
@@ -214,6 +214,48 @@ def random(self, *args, **kwargs):
214
214
"Define a custom random method and pass it as kwarg random" )
215
215
216
216
217
class _DrawValuesContext(six.with_metaclass(InitContextMeta, Context)):
    """Context manager used while drawing values with ``draw_values``.

    Nested instances share one ``drawn_vars`` dictionary, so separate
    branches of a nested context tree all see the same drawn values.
    """

    def __new__(cls, *args, **kwargs):
        # Resolve the parent instance before __init__ runs.
        instance = super(_DrawValuesContext, cls).__new__(cls)
        contexts = cls.get_contexts()
        # The innermost enclosing context is only a valid parent when it
        # is itself a _DrawValuesContext and not, e.g., a Model.
        if contexts and isinstance(contexts[-1], cls):
            instance._parent = contexts[-1]
        else:
            instance._parent = None
        return instance

    def __init__(self):
        if self.parent is None:
            self.drawn_vars = dict()
        else:
            # Share the parent's dictionary by reference so every
            # _DrawValuesContext nested inside another one records and
            # observes the same drawn values.
            self.drawn_vars = self.parent.drawn_vars

    @property
    def parent(self):
        # Read-only access to the resolved parent context (or None).
        return self._parent
252
def is_fast_drawable(var):
    """Return True when ``var`` already carries a concrete value.

    Plain numbers, numpy arrays, theano constants and shared variables
    need no symbolic evaluation and can be drawn immediately.
    """
    fast_types = (numbers.Number,
                  np.ndarray,
                  tt.TensorConstant,
                  tt.sharedvar.SharedVariable)
    return isinstance(var, fast_types)
def draw_values(params, point=None, size=None):
    """
    Draw (fix) parameter values. Handles a number of cases:

        1) The parameter is a number, array, theano constant or shared
           variable ("fast drawable"), or its ``name`` is given in ``point``
        2) The parameter was already drawn in a related _DrawValuesContext
        3) The parameter is symbolic, and either
            a) can be compiled and evaluated from its named inputs, or
            b) are *RVs with a random method

    Parameters
    ----------
    params : iterable
        The parameters (numbers, arrays, theano variables or *RVs) whose
        concrete values are required.
    point : dict, optional
        Mapping of variable names to fixed values; a parameter whose
        ``name`` appears here is taken from the point instead of sampled.
    size : int or tuple, optional
        Requested sample size, forwarded to ``_draw_value``.

    Returns
    -------
    list
        The drawn values, in the same order as ``params``.
    """
    # Get fast drawable values (i.e. things in point or numbers, arrays,
    # constants or shares, or things that were already drawn in related
    # contexts)
    if point is None:
        point = {}
    with _DrawValuesContext() as context:
        # Index params by position: some nodes are not hashable, and the
        # final result must preserve the caller's ordering.
        params = dict(enumerate(params))
        # drawn is shared with any enclosing _DrawValuesContext.
        drawn = context.drawn_vars
        evaluated = {}
        symbolic_params = []
        for i, p in params.items():
            # If the param is fast drawable, then draw the value immediately
            if is_fast_drawable(p):
                v = _draw_value(p, point=point, size=size)
                evaluated[i] = v
                continue

            name = getattr(p, 'name', None)
            if p in drawn:
                # param was drawn in related contexts
                v = drawn[p]
                evaluated[i] = v
            elif name is not None and name in point:
                # param.name is in point
                v = point[name]
                evaluated[i] = drawn[p] = v
            else:
                # param still needs to be drawn
                symbolic_params.append((i, p))

        if not symbolic_params:
            # We only need to enforce the correct order if there are symbolic
            # params that could be drawn in variable order
            return [evaluated[i] for i in params]

        # Distribution parameters may be nodes which have named node-inputs
        # specified in the point. Need to find the node-inputs, their
        # parents and children to replace them.
        leaf_nodes = {}
        named_nodes_parents = {}
        named_nodes_children = {}
        for _, param in symbolic_params:
            if hasattr(param, 'name'):
                # Get the named nodes under the `param` node
                nn, nnp, nnc = get_named_nodes_and_relations(param)
                leaf_nodes.update(nn)
                # Update the discovered parental relationships
                for k in nnp.keys():
                    if k not in named_nodes_parents.keys():
                        named_nodes_parents[k] = nnp[k]
                    else:
                        named_nodes_parents[k].update(nnp[k])
                # Update the discovered child relationships
                for k in nnc.keys():
                    if k not in named_nodes_children.keys():
                        named_nodes_children[k] = nnc[k]
                    else:
                        named_nodes_children[k].update(nnc[k])

        # Init givens and the stack of nodes to try to `_draw_value` from.
        # Seed givens with everything already drawn that has a name.
        givens = {p.name: (p, v) for p, v in drawn.items()
                  if getattr(p, 'name', None) is not None}
        stack = list(leaf_nodes.values())  # A queue would be more appropriate
        while stack:
            next_ = stack.pop(0)
            if next_ in drawn:
                # If the node already has a givens value, skip it
                continue
            elif isinstance(next_, (tt.TensorConstant,
                                    tt.sharedvar.SharedVariable)):
                # If the node is a theano.tensor.TensorConstant or a
                # theano.tensor.sharedvar.SharedVariable, its value will be
                # available automatically in _compile_theano_function so
                # we can skip it. Furthermore, if this node was treated as a
                # TensorVariable that should be compiled by theano in
                # _compile_theano_function, it would raise a `TypeError:
                # ('Constants not allowed in param list', ...)` for
                # TensorConstant, and a `TypeError: Cannot use a shared
                # variable (...) as explicit input` for SharedVariable.
                continue
            else:
                # If the node does not have a givens value, try to draw it.
                # The named node's children givens values must also be taken
                # into account.
                children = named_nodes_children[next_]
                temp_givens = [givens[k] for k in givens if k in children]
                try:
                    # This may fail for autotransformed RVs, which don't
                    # have the random method
                    value = _draw_value(next_,
                                        point=point,
                                        givens=temp_givens,
                                        size=size)
                    givens[next_.name] = (next_, value)
                    drawn[next_] = value
                except theano.gof.fg.MissingInputError:
                    # The node failed, so we must add the node's parents to
                    # the stack of nodes to try to draw from. We exclude the
                    # nodes in the `params` list.
                    # NOTE(review): `params` was re-bound above to a dict
                    # keyed by integer index, so `node not in params` compares
                    # nodes against int keys and is presumably always True —
                    # verify this still excludes the original params as
                    # intended.
                    stack.extend([node for node in named_nodes_parents[next_]
                                  if node is not None and
                                  node.name not in drawn and
                                  node not in params])

        # the below makes sure the graph is evaluated in order
        # test_distributions_random::TestDrawValues::test_draw_order fails without it
        # The remaining params that must be drawn are all hashable
        to_eval = set()
        missing_inputs = set([j for j, p in symbolic_params])
        while to_eval or missing_inputs:
            if to_eval == missing_inputs:
                # No progress was made in the last pass: the remaining
                # params cannot be resolved from what has been drawn.
                raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval]))
            to_eval = set(missing_inputs)
            missing_inputs = set()
            for param_idx in to_eval:
                param = params[param_idx]
                if param in drawn:
                    # Drawn meanwhile (e.g. as another param's input).
                    evaluated[param_idx] = drawn[param]
                else:
                    try:  # might evaluate in a bad order,
                        value = _draw_value(param,
                                            point=point,
                                            givens=givens.values(),
                                            size=size)
                        evaluated[param_idx] = drawn[param] = value
                        givens[param.name] = (param, value)
                    except theano.gof.fg.MissingInputError:
                        # Retry on the next pass, once more inputs exist.
                        missing_inputs.add(param_idx)

    return [evaluated[j] for j in params]  # set the order back
@@ -400,8 +479,16 @@ def _draw_value(param, point=None, givens=None, size=None):
400
479
# reset shape to account for shape changes
401
480
# with theano.shared inputs
402
481
dist_tmp .shape = np .array ([])
403
- val = dist_tmp .random (point = point , size = None )
404
- dist_tmp .shape = val .shape
482
+ val = np .atleast_1d (dist_tmp .random (point = point ,
483
+ size = None ))
484
+ # Sometimes point may change the size of val but not the
485
+ # distribution's shape
486
+ if point and size is not None :
487
+ temp_size = np .atleast_1d (size )
488
+ if all (val .shape [:len (temp_size )] == temp_size ):
489
+ dist_tmp .shape = val .shape [len (temp_size ):]
490
+ else :
491
+ dist_tmp .shape = val .shape
405
492
return dist_tmp .random (point = point , size = size )
406
493
else :
407
494
return param .distribution .random (point = point , size = size )
@@ -411,10 +498,24 @@ def _draw_value(param, point=None, givens=None, size=None):
411
498
else :
412
499
variables = values = []
413
500
func = _compile_theano_function (param , variables )
414
- if size and values and not all (var .dshape == val .shape for var , val in zip (variables , values )):
415
- return np .array ([func (* v ) for v in zip (* values )])
501
+ if size is not None :
502
+ size = np .atleast_1d (size )
503
+ dshaped_variables = all ((hasattr (var , 'dshape' )
504
+ for var in variables ))
505
+ if (values and dshaped_variables and
506
+ not all (var .dshape == getattr (val , 'shape' , tuple ())
507
+ for var , val in zip (variables , values ))):
508
+ output = np .array ([func (* v ) for v in zip (* values )])
509
+ elif (size is not None and any ((val .ndim > var .ndim )
510
+ for var , val in zip (variables , values ))):
511
+ output = np .array ([func (* v ) for v in zip (* values )])
416
512
else :
417
- return func (* values )
513
+ output = func (* values )
514
+ return output
515
+ print (param ,
516
+ type (param ),
517
+ isinstance (param , tt .TensorVariable ),
518
+ isinstance (param , (tt .TensorVariable , MultiObservedRV )))
418
519
raise ValueError ('Unexpected type in draw_value: %s' % type (param ))
419
520
420
521
@@ -499,6 +600,20 @@ def generate_samples(generator, *args, **kwargs):
499
600
samples = generator (size = broadcast_shape , * args , ** kwargs )
500
601
elif dist_shape == broadcast_shape :
501
602
samples = generator (size = size_tup + dist_shape , * args , ** kwargs )
603
+ elif len (dist_shape ) == 0 and size_tup and broadcast_shape [:len (size_tup )] == size_tup :
604
+ # Input's dist_shape is scalar, but it has size repetitions.
605
+ # So now the size matches but we have to manually broadcast to
606
+ # the right dist_shape
607
+ samples = [generator (* args , ** kwargs )]
608
+ if samples [0 ].shape == broadcast_shape :
609
+ samples = samples [0 ]
610
+ else :
611
+ suffix = broadcast_shape [len (size_tup ):] + dist_shape
612
+ samples .extend ([generator (* args , ** kwargs ).
613
+ reshape (broadcast_shape )[..., np .newaxis ]
614
+ for _ in range (np .prod (suffix ,
615
+ dtype = int ) - 1 )])
616
+ samples = np .hstack (samples ).reshape (size_tup + suffix )
502
617
else :
503
618
samples = None
504
619
# Args have been broadcast correctly, can just ask for the right shape out
@@ -515,9 +630,11 @@ def generate_samples(generator, *args, **kwargs):
515
630
if samples is None :
516
631
raise TypeError ('''Attempted to generate values with incompatible shapes:
517
632
size: {size}
633
+ size_tup: {size_tup}
634
+ broadcast_shape[:len(size_tup)] == size_tup: {test}
518
635
dist_shape: {dist_shape}
519
636
broadcast_shape: {broadcast_shape}
520
- ''' .format (size = size , dist_shape = dist_shape , broadcast_shape = broadcast_shape ))
637
+ ''' .format (size = size , size_tup = size_tup , dist_shape = dist_shape , broadcast_shape = broadcast_shape , test = broadcast_shape [: len ( size_tup )] == size_tup ))
521
638
522
639
# reshape samples here
523
640
if samples .shape [0 ] == 1 and size == 1 :
0 commit comments