1
1
import threading
2
-
2
+ import six
3
3
import numpy as np
4
4
import theano
5
5
import theano .tensor as tt
@@ -172,17 +172,215 @@ def logpt(self):
172
172
return tt .sum (self .logp_elemwiset )
173
173
174
174
175
- class Model (Context , Factor ):
176
- """Encapsulates the variables and likelihood factors of a model."""
175
+ class InitContextMeta (type ):
176
+ """Metaclass that executes `__init__` of instance in it's context"""
177
+ def __call__ (cls , * args , ** kwargs ):
178
+ instance = cls .__new__ (cls , * args , ** kwargs )
179
+ with instance : # appends context
180
+ instance .__init__ (* args , ** kwargs )
181
+ return instance
182
+
183
+
184
+ def withparent (meth ):
185
+ """Helper wrapper that passes calls to parent's instance"""
186
+ def wrapped (self , * args , ** kwargs ):
187
+ res = meth (self , * args , ** kwargs )
188
+ if getattr (self , 'parent' , None ) is not None :
189
+ getattr (self .parent , meth .__name__ )(* args , ** kwargs )
190
+ return res
191
+ # Unfortunately functools wrapper fails
192
+ # when decorating built-in methods so we
193
+ # need to fix that improper behaviour
194
+ wrapped .__name__ = meth .__name__
195
+ return wrapped
196
+
197
+
198
+ class treelist (list ):
199
+ """A list that passes mutable extending operations used in Model
200
+ to parent list instance.
201
+ Extending treelist you will also extend its parent
202
+ """
203
+ def __init__ (self , iterable = (), parent = None ):
204
+ super (treelist , self ).__init__ (iterable )
205
+ assert isinstance (parent , list ) or parent is None
206
+ self .parent = parent
207
+ if self .parent is not None :
208
+ self .parent .extend (self )
209
+ # typechecking here works bad
210
+ append = withparent (list .append )
211
+ __iadd__ = withparent (list .__iadd__ )
212
+ extend = withparent (list .extend )
213
+
214
+ def tree_contains (self , item ):
215
+ if isinstance (self .parent , treedict ):
216
+ return (list .__contains__ (self , item ) or
217
+ self .parent .tree_contains (item ))
218
+ elif isinstance (self .parent , list ):
219
+ return (list .__contains__ (self , item ) or
220
+ self .parent .__contains__ (item ))
221
+ else :
222
+ return list .__contains__ (self , item )
223
+
224
+ def __setitem__ (self , key , value ):
225
+ raise NotImplementedError ('Method is removed as we are not'
226
+ ' able to determine '
227
+ 'appropriate logic for it' )
228
+
229
+ def __imul__ (self , other ):
230
+ t0 = len (self )
231
+ list .__imul__ (self , other )
232
+ if self .parent is not None :
233
+ self .parent .extend (self [t0 :])
234
+
235
+
236
+ class treedict (dict ):
237
+ """A dict that passes mutable extending operations used in Model
238
+ to parent dict instance.
239
+ Extending treedict you will also extend its parent
240
+ """
241
+ def __init__ (self , iterable = (), parent = None , ** kwargs ):
242
+ super (treedict , self ).__init__ (iterable , ** kwargs )
243
+ assert isinstance (parent , dict ) or parent is None
244
+ self .parent = parent
245
+ if self .parent is not None :
246
+ self .parent .update (self )
247
+ # typechecking here works bad
248
+ __setitem__ = withparent (dict .__setitem__ )
249
+ update = withparent (dict .update )
250
+
251
+ def tree_contains (self , item ):
252
+ # needed for `add_random_variable` method
253
+ if isinstance (self .parent , treedict ):
254
+ return (dict .__contains__ (self , item ) or
255
+ self .parent .tree_contains (item ))
256
+ elif isinstance (self .parent , dict ):
257
+ return (dict .__contains__ (self , item ) or
258
+ self .parent .__contains__ (item ))
259
+ else :
260
+ return dict .__contains__ (self , item )
261
+
262
+
263
+ class Model (six .with_metaclass (InitContextMeta , Context , Factor )):
264
+ """Encapsulates the variables and likelihood factors of a model.
265
+
266
+ Model class can be used for creating class based models. To create
267
+ a class based model you should inherit from `Model` and
268
+ override `__init__` with arbitrary definitions
269
+ (do not forget to call base class `__init__` first).
270
+
271
+ Parameters
272
+ ----------
273
+ name : str, default '' - name that will be used as prefix for
274
+ names of all random variables defined within model
275
+ model : Model, default None - instance of Model that is
276
+ supposed to be a parent for the new instance. If None,
277
+ context will be used. All variables defined within instance
278
+ will be passed to the parent instance. So that 'nested' model
279
+ contributes to the variables and likelihood factors of
280
+ parent model.
281
+
282
+ Examples
283
+ --------
284
+ # How to define a custom model
285
+ class CustomModel(Model):
286
+ # 1) override init
287
+ def __init__(self, mean=0, sd=1, name='', model=None):
288
+ # 2) call super's init first, passing model and name to it
289
+ # name will be prefix for all variables here
290
+ # if no name specified for model there will be no prefix
291
+ super(CustomModel, self).__init__(name, model)
292
+ # now you are in the context of instance,
293
+ # `modelcontext` will return self
294
+ # you can define variables in several ways
295
+ # note, that all variables will get model's name prefix
296
+
297
+ # 3) you can create variables with Var method
298
+ self.Var('v1', Normal.dist(mu=mean, sd=sd))
299
+ # this will create variable named like '{prefix_}v1'
300
+ # and assign attribute 'v1' to instance
301
+ # created variable can be accessed with self.v1 or self['v1']
302
+
303
+ # 4) this syntax will also work as we are in the context
304
+ # of instance itself, names are given as usual
305
+ Normal('v2', mu=mean, sd=sd)
306
+
307
+ # something more complex is allowed too
308
+ Normal('v3', mu=mean, sd=HalfCauchy('sd', beta=10, testval=1.))
309
+
310
+ # Deterministic variables can be used in usual way
311
+ Deterministic('v3_sq', self.v3 ** 2)
312
+ # Potentials too
313
+ Potential('p1', tt.constant(1))
314
+
315
+ # After defining a class CustomModel you can use it in several ways
316
+
317
+ # I:
318
+ # state the model within a context
319
+ with Model() as model:
320
+ CustomModel()
321
+ # arbitrary actions
322
+
323
+ # II:
324
+ # use new class as entering point in context
325
+ with CustomModel() as model:
326
+ Normal('new_normal_var', mu=1, sd=0)
327
+
328
+ # III:
329
+ # just get model instance with all that was defined in it
330
+ model = CustomModel()
331
+
332
+ # IV:
333
+ # use many custom models within one context
334
+ with Model() as model:
335
+ CustomModel(mean=1, name='first')
336
+ CustomModel(mean=2, name='second')
337
+ """
338
+ def __new__ (cls , * args , ** kwargs ):
339
+ # resolves the parent instance
340
+ instance = object .__new__ (cls )
341
+ if kwargs .get ('model' ) is not None :
342
+ instance ._parent = kwargs .get ('model' )
343
+ elif cls .get_contexts ():
344
+ instance ._parent = cls .get_contexts ()[- 1 ]
345
+ else :
346
+ instance ._parent = None
347
+ return instance
348
+
349
+ def __init__ (self , name = '' , model = None ):
350
+ self .name = name
351
+ if self .parent is not None :
352
+ self .named_vars = treedict (parent = self .parent .named_vars )
353
+ self .free_RVs = treelist (parent = self .parent .free_RVs )
354
+ self .observed_RVs = treelist (parent = self .parent .observed_RVs )
355
+ self .deterministics = treelist (parent = self .parent .deterministics )
356
+ self .potentials = treelist (parent = self .parent .potentials )
357
+ self .missing_values = treelist (parent = self .parent .missing_values )
358
+ else :
359
+ self .named_vars = treedict ()
360
+ self .free_RVs = treelist ()
361
+ self .observed_RVs = treelist ()
362
+ self .deterministics = treelist ()
363
+ self .potentials = treelist ()
364
+ self .missing_values = treelist ()
365
+
366
+ @property
367
+ def model (self ):
368
+ return self
369
+
370
+ @property
371
+ def parent (self ):
372
+ return self ._parent
373
+
374
+ @property
375
+ def root (self ):
376
+ model = self
377
+ while not model .isroot :
378
+ model = model .parent
379
+ return model
177
380
178
- def __init__ (self ):
179
- self .named_vars = {}
180
- self .free_RVs = []
181
- self .observed_RVs = []
182
- self .deterministics = []
183
- self .potentials = []
184
- self .missing_values = []
185
- self .model = self
381
+ @property
382
+ def isroot (self ):
383
+ return self .parent is None
186
384
187
385
@property
188
386
@memoize
@@ -275,6 +473,7 @@ def Var(self, name, dist, data=None):
275
473
-------
276
474
FreeRV or ObservedRV
277
475
"""
476
+ name = self .name_for (name )
278
477
if data is None :
279
478
if getattr (dist , "transform" , None ) is None :
280
479
var = FreeRV (name = name , distribution = dist , model = self )
@@ -312,15 +511,46 @@ def Var(self, name, dist, data=None):
312
511
313
512
def add_random_variable (self , var ):
314
513
"""Add a random variable to the named variables of the model."""
315
- if var . name in self .named_vars :
514
+ if self .named_vars . tree_contains ( var . name ) :
316
515
raise ValueError (
317
516
"Variable name {} already exists." .format (var .name ))
318
517
self .named_vars [var .name ] = var
319
- if not hasattr (self , var .name ):
320
- setattr (self , var .name , var )
518
+ if not hasattr (self , self .name_of (var .name )):
519
+ setattr (self , self .name_of (var .name ), var )
520
+
521
+ @property
522
+ def prefix (self ):
523
+ return '%s_' % self .name if self .name else ''
524
+
525
+ def name_for (self , name ):
526
+ """Checks if name has prefix and adds if needed
527
+ """
528
+ if self .prefix :
529
+ if not name .startswith (self .prefix ):
530
+ return '{}{}' .format (self .prefix , name )
531
+ else :
532
+ return name
533
+ else :
534
+ return name
535
+
536
+ def name_of (self , name ):
537
+ """Checks if name has prefix and deletes if needed
538
+ """
539
+ if not self .prefix or not name :
540
+ return name
541
+ elif name .startswith (self .prefix ):
542
+ return name [len (self .prefix ):]
543
+ else :
544
+ return name
321
545
322
546
def __getitem__ (self , key ):
323
- return self .named_vars [key ]
547
+ try :
548
+ return self .named_vars [key ]
549
+ except KeyError as e :
550
+ try :
551
+ return self .named_vars [self .name_for (key )]
552
+ except KeyError :
553
+ raise e
324
554
325
555
@memoize
326
556
def makefn (self , outs , mode = None , * args , ** kwargs ):
@@ -637,9 +867,10 @@ def Deterministic(name, var, model=None):
637
867
-------
638
868
n : var but with name name
639
869
"""
640
- var .name = name
641
- modelcontext (model ).deterministics .append (var )
642
- modelcontext (model ).add_random_variable (var )
870
+ model = modelcontext (model )
871
+ var .name = model .name_for (name )
872
+ model .deterministics .append (var )
873
+ model .add_random_variable (var )
643
874
return var
644
875
645
876
@@ -655,8 +886,9 @@ def Potential(name, var, model=None):
655
886
-------
656
887
var : var, with name attribute
657
888
"""
658
- var .name = name
659
- modelcontext (model ).potentials .append (var )
889
+ model = modelcontext (model )
890
+ var .name = model .name_for (name )
891
+ model .potentials .append (var )
660
892
return var
661
893
662
894
0 commit comments