@@ -277,6 +277,168 @@ def tree_contains(self, item):
        return dict.__contains__(self, item)
+class ValueGradFunction(object):
+    """Create a theano function that computes a value and its gradient.
+
+    Parameters
+    ----------
+    cost : theano variable
+        The value that we compute together with its gradient.
+    grad_vars : list of named theano variables or None
+        The arguments with respect to which the gradient is computed.
+    extra_vars : list of named theano variables or None
+        Other arguments of the function that are assumed constant. They
+        are stored in shared variables and can be set using
+        `set_extra_values`.
+    dtype : str, default=theano.config.floatX
+        The dtype of the arrays.
+    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, default='no'
+        Casting rule for casting `grad_vars` to the array dtype.
+        See `numpy.can_cast` for a description of the options.
+        Keep in mind that we cast the variables to the array dtype *and*
+        back from the array dtype to the variable dtype.
+    kwargs
+        Extra arguments are passed on to `theano.function`.
+
+    Attributes
+    ----------
+    size : int
+        The number of elements in the parameter array.
+    profile : theano profiling object or None
+        The profiling object of the theano function that computes value and
+        gradient. This is None unless `profile=True` was set in the
+        kwargs.
+    """
+    def __init__(self, cost, grad_vars, extra_vars=None, dtype=None,
+                 casting='no', **kwargs):
+        if extra_vars is None:
+            extra_vars = []
+
+        names = [arg.name for arg in grad_vars + extra_vars]
+        if any(name is None for name in names):
+            raise ValueError('Arguments must be named.')
+        if len(set(names)) != len(names):
+            raise ValueError('Names of the arguments are not unique.')
+
+        if cost.ndim > 0:
+            raise ValueError('Cost must be a scalar.')
+
+        self._grad_vars = grad_vars
+        self._extra_vars = extra_vars
+        self._extra_var_names = set(var.name for var in extra_vars)
+        self._cost = cost
+        self._ordering = ArrayOrdering(grad_vars)
+        self.size = self._ordering.size
+        self._extra_are_set = False
+        if dtype is None:
+            dtype = theano.config.floatX
+        self._dtype = dtype
+        for var in self._grad_vars:
+            if not np.can_cast(var.dtype, self._dtype, casting):
+                raise TypeError('Invalid dtype for variable %s. Can not '
+                                'cast to %s with casting rule %s.'
+                                % (var.name, self._dtype, casting))
+
+        givens = []
+        self._extra_vars_shared = {}
+        for var in extra_vars:
+            shared = theano.shared(var.tag.test_value, var.name + '_shared__')
+            self._extra_vars_shared[var.name] = shared
+            givens.append((var, shared))
+
+        self._vars_joined, self._cost_joined = self._build_joined(
+            self._cost, grad_vars, self._ordering.vmap)
+
+        grad = tt.grad(self._cost_joined, self._vars_joined)
+
+        inputs = [self._vars_joined]
+
+        self._theano_function = theano.function(
+            inputs, [self._cost_joined, grad], givens=givens, **kwargs)
+
+    def set_extra_values(self, extra_vars):
+        self._extra_are_set = True
+        for var in self._extra_vars:
+            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
+
+    def get_extra_values(self):
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        return {var.name: self._extra_vars_shared[var.name].get_value()
+                for var in self._extra_vars}
+
+    def __call__(self, array, grad_out=None, extra_vars=None):
+        if extra_vars is not None:
+            self.set_extra_values(extra_vars)
+
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        if array.shape != (self.size,):
+            raise ValueError('Invalid shape for array. Must be %s but is %s.'
+                             % ((self.size,), array.shape))
+
+        if grad_out is None:
+            out = np.empty_like(array)
+        else:
+            out = grad_out
+
+        logp, dlogp = self._theano_function(array)
+        if grad_out is None:
+            return logp, dlogp
+        else:
+            out[...] = dlogp
+            return logp
+
+    @property
+    def profile(self):
+        """Profiling information of the underlying theano function."""
+        return self._theano_function.profile
+
+    def dict_to_array(self, point):
+        """Convert a dictionary with values for grad_vars to an array."""
+        array = np.empty(self.size, dtype=self._dtype)
+        for varmap in self._ordering.vmap:
+            array[varmap.slc] = point[varmap.var].ravel().astype(self._dtype)
+        return array
+
+    def array_to_dict(self, array):
+        """Convert an array to a dictionary containing the grad_vars."""
+        if array.shape != (self.size,):
+            raise ValueError('Array should have shape (%s,) but has %s'
+                             % (self.size, array.shape))
+        if array.dtype != self._dtype:
+            raise ValueError('Array has invalid dtype. Should be %s but is %s'
+                             % (self._dtype, array.dtype))
+        point = {}
+        for varmap in self._ordering.vmap:
+            data = array[varmap.slc].reshape(varmap.shp)
+            point[varmap.var] = data.astype(varmap.dtyp)
+
+        return point
+
+    def array_to_full_dict(self, array):
+        """Convert an array to a dictionary with grad_vars and extra_vars."""
+        point = self.array_to_dict(array)
+        for name, var in self._extra_vars_shared.items():
+            point[name] = var.get_value()
+        return point
+
+    def _build_joined(self, cost, args, vmap):
+        args_joined = tt.vector('__args_joined')
+        dtype = theano.config.floatX
+        args_joined.tag.test_value = np.zeros(self.size, dtype=dtype)
+
+        joined_slices = {}
+        for var_map in vmap:
+            sliced = args_joined[var_map.slc].reshape(var_map.shp)
+            joined_slices[var_map.var] = sliced
+
+        replace = {var: joined_slices[var.name] for var in args}
+        return args_joined, theano.clone(cost, replace=replace)
+
+
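A minimal usage sketch of the class added above (not part of the patch; the import path, toy model, and variable names are assumptions for illustration):

import numpy as np
import pymc3 as pm
from pymc3.model import ValueGradFunction

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0., sd=1.)
    pm.Normal('obs', mu=mu, sd=1., observed=np.random.randn(10))

# Value-and-gradient function over all free variables; with no extra_vars,
# an empty dict is enough to mark the extra values as set.
func = ValueGradFunction(model.logpt, model.free_RVs, [])
func.set_extra_values({})

x = func.dict_to_array(model.test_point)   # flatten a {name: value} dict
logp, dlogp = func(x)                      # log-posterior and its gradient
point = func.array_to_dict(x)              # round-trip back to a dict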
class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
    """Encapsulates the variables and likelihood factors of a model.
@@ -419,7 +581,6 @@ def bijection(self):
        return bij

    @property
-    @memoize
    def dict_to_array(self):
        return self.bijection.map
@@ -428,18 +589,27 @@ def ndim(self):
        return sum(var.dsize for var in self.free_RVs)

    @property
-    @memoize
    def logp_array(self):
        return self.bijection.mapf(self.fastlogp)

    @property
-    @memoize
    def dlogp_array(self):
        vars = inputvars(self.cont_vars)
        return self.bijection.mapf(self.fastdlogp(vars))

+    def logp_dlogp_function(self, grad_vars=None, **kwargs):
+        if grad_vars is None:
+            grad_vars = list(typefilter(self.free_RVs, continuous_types))
+        else:
+            for var in grad_vars:
+                if var.dtype not in continuous_types:
+                    raise ValueError("Can only compute the gradient of "
+                                     "continuous types: %s" % var)
+        varnames = [var.name for var in grad_vars]
+        extra_vars = [var for var in self.free_RVs if var.name not in varnames]
+        return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)
+
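A sketch of the new Model.logp_dlogp_function hook (again an assumed toy model, not from the patch). Variables left out of grad_vars are held fixed through set_extra_values, and extra kwargs such as profile=True are forwarded to theano.function:

import pymc3 as pm

with pm.Model() as model:
    a = pm.Normal('a', mu=0., sd=1.)
    b = pm.Normal('b', mu=0., sd=1.)

# Gradient only with respect to `a`; `b` becomes an extra (shared) value.
f = model.logp_dlogp_function(grad_vars=[a])
f.set_extra_values({'b': model.test_point['b']})

x = f.dict_to_array(model.test_point)   # length f.size == 1, holds `a` only
logp, grad = f(x)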
    @property
-    @memoize
    def logpt(self):
        """Theano scalar of log-probability of the model"""
        with self:
@@ -595,7 +765,6 @@ def __getitem__(self, key):
        except KeyError:
            raise e

-    @memoize
    def makefn(self, outs, mode=None, *args, **kwargs):
        """Compiles a Theano function which returns `outs` and takes the variable
        ancestors of `outs` as inputs.