1
+ import collections
1
2
import numbers
3
+
2
4
import numpy as np
3
5
import theano .tensor as tt
4
6
from theano import function
@@ -254,7 +256,7 @@ def draw_values(params, point=None, size=None):
254
256
255
257
# Init givens and the stack of nodes to try to `_draw_value` from
256
258
givens = {}
257
- stored = set ([] ) # Some nodes
259
+ stored = set () # Some nodes
258
260
stack = list (leaf_nodes .values ()) # A queue would be more appropriate
259
261
while stack :
260
262
next_ = stack .pop (0 )
@@ -279,13 +281,14 @@ def draw_values(params, point=None, size=None):
279
281
# The named node's children givens values must also be taken
280
282
# into account.
281
283
children = named_nodes_children [next_ ]
282
- temp_givens = [givens [k ] for k in givens . keys () if k in children ]
284
+ temp_givens = [givens [k ] for k in givens if k in children ]
283
285
try :
284
286
# This may fail for autotransformed RVs, which don't
285
287
# have the random method
286
288
givens [next_ .name ] = (next_ , _draw_value (next_ ,
287
289
point = point ,
288
- givens = temp_givens , size = size ))
290
+ givens = temp_givens ,
291
+ size = size ))
289
292
stored .add (next_ .name )
290
293
except theano .gof .fg .MissingInputError :
291
294
# The node failed, so we must add the node's parents to
@@ -295,10 +298,31 @@ def draw_values(params, point=None, size=None):
295
298
if node is not None and
296
299
node .name not in stored and
297
300
node not in params ])
298
- values = []
299
- for param in params :
300
- values .append (_draw_value (param , point = point , givens = givens .values (), size = size ))
301
- return values
301
+
302
+ # the below makes sure the graph is evaluated in order
303
+ # test_distributions_random::TestDrawValues::test_draw_order fails without it
304
+ params = dict (enumerate (params )) # some nodes are not hashable
305
+ evaluated = {}
306
+ to_eval = set ()
307
+ missing_inputs = set (params )
308
+ while to_eval or missing_inputs :
309
+ if to_eval == missing_inputs :
310
+ raise ValueError ('Cannot resolve inputs for {}' .format ([str (params [j ]) for j in to_eval ]))
311
+ to_eval = set (missing_inputs )
312
+ missing_inputs = set ()
313
+ for param_idx in to_eval :
314
+ param = params [param_idx ]
315
+ if hasattr (param , 'name' ) and param .name in givens :
316
+ evaluated [param_idx ] = givens [param .name ][1 ]
317
+ else :
318
+ try : # might evaluate in a bad order,
319
+ evaluated [param_idx ] = _draw_value (param , point = point , givens = givens .values (), size = size )
320
+ if isinstance (param , collections .Hashable ) and named_nodes_parents .get (param ):
321
+ givens [param .name ] = (param , evaluated [param_idx ])
322
+ except theano .gof .fg .MissingInputError :
323
+ missing_inputs .add (param_idx )
324
+
325
+ return [evaluated [j ] for j in params ] # set the order back
302
326
303
327
304
328
@memoize
@@ -356,43 +380,26 @@ def _draw_value(param, point=None, givens=None, size=None):
356
380
return point [param .name ]
357
381
elif hasattr (param , 'random' ) and param .random is not None :
358
382
return param .random (point = point , size = size )
383
+ elif (hasattr (param , 'distribution' ) and
384
+ hasattr (param .distribution , 'random' ) and
385
+ param .distribution .random is not None ):
386
+ return param .distribution .random (point = point , size = size )
359
387
else :
360
388
if givens :
361
389
variables , values = list (zip (* givens ))
362
390
else :
363
391
variables = values = []
364
392
func = _compile_theano_function (param , variables )
365
- return func (* values )
393
+ if size and values and not all (var .dshape == val .shape for var , val in zip (variables , values )):
394
+ return np .array ([func (* v ) for v in zip (* values )])
395
+ else :
396
+ return func (* values )
366
397
else :
367
398
raise ValueError ('Unexpected type in draw_value: %s' % type (param ))
368
399
369
400
370
def broadcast_shapes(*args):
    """Return the shape resulting from broadcasting multiple shapes.

    Follows numpy's broadcasting rules: shapes are aligned on their
    trailing axes, and two aligned dimensions are compatible when they
    are equal or when one of them is 1.

    Parameters
    ----------
    *args : array-like of int
        Tuples, arrays, lists or plain ints representing the shapes of
        the arrays to be broadcast. An empty shape ``()`` is treated as
        a scalar and is compatible with every other shape.

    Returns
    -------
    tuple of int, or None
        The broadcast shape, or None if broadcasting is not possible.
        With no arguments the result is ``()``.
    """
    result = []
    for arg in args:
        shape = [int(dim) for dim in np.atleast_1d(arg)]

        # Align on the trailing axes by left-padding the shorter shape
        # with 1s.  This also makes an empty (scalar) shape harmless,
        # instead of slicing with ``-0`` and wiping the accumulated result.
        rank = max(len(result), len(shape))
        result = [1] * (rank - len(result)) + result
        shape = [1] * (rank - len(shape)) + shape

        merged = []
        for left, right in zip(result, shape):
            if left == right or right == 1:
                merged.append(left)
            elif left == 1:
                merged.append(right)
            else:
                # Fail explicitly rather than through a 0 sentinel, so that
                # genuine zero-length dimensions remain valid.
                return None
        result = merged
    return tuple(result)
393
-
394
-
395
- def infer_shape (shape ):
401
+ def to_tuple (shape ):
402
+ """Convert ints, arrays, and Nones to tuples"""
396
403
try :
397
404
shape = tuple (shape or ())
398
405
except TypeError : # If size is an int
@@ -401,27 +408,14 @@ def infer_shape(shape):
401
408
shape = tuple (shape )
402
409
return shape
403
410
404
-
405
def reshape_sampled(sampled, size, dist_shape):
    """Reshape ``sampled`` to ``size + dist_shape`` when a target shape applies.

    Both ``size`` and ``dist_shape`` are normalised through ``infer_shape``.
    The samples are returned unchanged only when they hold more than one
    element and both normalised shapes are empty.
    """
    trailing = infer_shape(dist_shape)
    leading = infer_shape(size)

    # Nothing to reshape to: multi-element samples with no requested shape.
    if np.size(sampled) != 1 and not leading and not trailing:
        return sampled
    return np.reshape(sampled, leading + trailing)
413
-
414
-
415
def replicate_samples(generator, size, repeats, *args, **kwargs):
    """Call ``generator`` once per repeat and stack the draws.

    ``generator`` is invoked with ``size=size`` plus the extra positional
    and keyword arguments.  When the product of ``repeats`` is 1 the single
    draw is returned as-is (no leading repeat axes are added); otherwise the
    draws are collected and reshaped to ``tuple(repeats) + tuple(size)``.
    """
    total_draws = int(np.prod(repeats))
    if total_draws == 1:
        return generator(*args, size=size, **kwargs)

    draws = []
    for _ in range(total_draws):
        draws.append(generator(*args, size=size, **kwargs))
    target_shape = tuple(repeats) + tuple(size)
    return np.reshape(np.array(draws), target_shape)
424
-
411
def _is_one_d(dist_shape):
    """Tell whether ``dist_shape`` describes a scalar-like shape.

    Returns True when the object's ``dshape`` attribute, or failing that
    its ``shape`` attribute, is one of ``()``, ``(0,)`` or ``(1,)``, or
    when ``dist_shape`` itself is the empty tuple.  Returns False in every
    other case.
    """
    scalar_like = ((), (0,), (1,))
    for attr in ('dshape', 'shape'):
        if hasattr(dist_shape, attr) and getattr(dist_shape, attr) in scalar_like:
            return True
    # Explicit empty-tuple check: ``dist_shape == ()`` is an elementwise
    # comparison for ndarrays and has no reliable truth value there.
    return isinstance(dist_shape, tuple) and len(dist_shape) == 0
425
419
426
420
def generate_samples (generator , * args , ** kwargs ):
427
421
"""Generate samples from the distribution of a random variable.
@@ -453,42 +447,60 @@ def generate_samples(generator, *args, **kwargs):
453
447
Any remaining *args and **kwargs are passed on to the generator function.
454
448
"""
455
449
dist_shape = kwargs .pop ('dist_shape' , ())
450
+ one_d = _is_one_d (dist_shape )
456
451
size = kwargs .pop ('size' , None )
457
452
broadcast_shape = kwargs .pop ('broadcast_shape' , None )
458
- params = args + tuple (kwargs .values ())
459
-
460
- if broadcast_shape is None :
461
- broadcast_shape = broadcast_shapes (* [np .atleast_1d (p ).shape for p in params
462
- if not isinstance (p , tuple )])
463
- if broadcast_shape == ():
464
- broadcast_shape = (1 ,)
453
+ if size is None :
454
+ size = 1
465
455
466
456
args = tuple (p [0 ] if isinstance (p , tuple ) else p for p in args )
457
+
467
458
for key in kwargs :
468
459
p = kwargs [key ]
469
460
kwargs [key ] = p [0 ] if isinstance (p , tuple ) else p
470
461
471
- if np .all (dist_shape [- len (broadcast_shape ):] == broadcast_shape ):
472
- prefix_shape = tuple (dist_shape [:- len (broadcast_shape )])
473
- else :
474
- prefix_shape = tuple (dist_shape )
475
-
476
- repeat_shape = infer_shape (size )
477
-
478
- if broadcast_shape == (1 ,) and prefix_shape == ():
479
- if size is not None :
480
- samples = generator (size = size , * args , ** kwargs )
462
+ if broadcast_shape is None :
463
+ inputs = args + tuple (kwargs .values ())
464
+ broadcast_shape = np .broadcast (* inputs ).shape # size of generator(size=1)
465
+
466
+ dist_shape = to_tuple (dist_shape )
467
+ broadcast_shape = to_tuple (broadcast_shape )
468
+ size_tup = to_tuple (size )
469
+
470
+ # All inputs are scalars, end up size (size_tup, dist_shape)
471
+ if broadcast_shape in {(), (0 ,), (1 ,)}:
472
+ samples = generator (size = size_tup + dist_shape , * args , ** kwargs )
473
+ # Inputs already have the right shape. Just get the right size.
474
+ elif broadcast_shape [- len (dist_shape ):] == dist_shape or len (dist_shape ) == 0 :
475
+ if size == 1 or (broadcast_shape == size_tup + dist_shape ):
476
+ samples = generator (size = broadcast_shape , * args , ** kwargs )
477
+ elif dist_shape == broadcast_shape :
478
+ samples = generator (size = size_tup + dist_shape , * args , ** kwargs )
481
479
else :
482
- samples = generator (size = 1 , * args , ** kwargs )
480
+ samples = None
481
+ # Args have been broadcast correctly, can just ask for the right shape out
482
+ elif dist_shape [- len (broadcast_shape ):] == broadcast_shape :
483
+ samples = generator (size = size_tup + dist_shape , * args , ** kwargs )
484
+ # Inputs have the right size, have to manually broadcast to the right dist_shape
485
+ elif broadcast_shape [:len (size_tup )] == size_tup :
486
+ suffix = broadcast_shape [len (size_tup ):] + dist_shape
487
+ samples = [generator (* args , ** kwargs ).reshape (size_tup + (1 ,)) for _ in range (np .prod (suffix , dtype = int ))]
488
+ samples = np .hstack (samples ).reshape (size_tup + suffix )
483
489
else :
484
- if size is not None :
485
- samples = replicate_samples (generator ,
486
- broadcast_shape ,
487
- repeat_shape + prefix_shape ,
488
- * args , ** kwargs )
489
- else :
490
- samples = replicate_samples (generator ,
491
- broadcast_shape ,
492
- prefix_shape ,
493
- * args , ** kwargs )
494
- return reshape_sampled (samples , size , dist_shape )
490
+ samples = None
491
+
492
+ if samples is None :
493
+ raise TypeError ('''Attempted to generate values with incompatible shapes:
494
+ size: {size}
495
+ dist_shape: {dist_shape}
496
+ broadcast_shape: {broadcast_shape}
497
+ ''' .format (size = size , dist_shape = dist_shape , broadcast_shape = broadcast_shape ))
498
+
499
+ # reshape samples here
500
+ if samples .shape [0 ] == 1 and size == 1 :
501
+ if len (samples .shape ) > len (dist_shape ) and samples .shape [- len (dist_shape ):] == dist_shape :
502
+ samples = samples .reshape (samples .shape [1 :])
503
+
504
+ if one_d and samples .shape [- 1 ] == 1 :
505
+ samples = samples .reshape (samples .shape [:- 1 ])
506
+ return np .asarray (samples )
0 commit comments