Skip to content

Commit 40d0146

Browse files
ferrinetwiecki
authored andcommitted
ENH User model (#1525)
* Started to write Base class for pymc3.models * mode `add_var` to public api * Added some docstrings * Added some docstrings * added getitem and fixed a typo * added assertion check * added resolve var method * decided not to add resolve method * Added linear component * Docs fix * patsy's intercept is inited properly now * refactored code * updated docs * added possibility to init coefficients with random variables * added glm * refactored api, fixed formula init * refactored linear model, extended acceptable types * moved useful matrix and labels creation to utils file * code style * removed redundant evaluation of shape * refactored resolver for constructing matrix and labels * changed error message * changed signature of init * simplified utils any_to_tensor_and_labels code * tests for `any_to_tensor_and_labels` * added docstring for `any_to_tensor_and_labels` util * forgot to document return type in `any_to_tensor_and_labels` * refactored code for dict * dict tests fix(do not check labels there) * added access to random vars of model * added a shortcut for all variables so there is a unified way to get them * added default priors for linear model * update docs for linear * refactored UserModel api, made it more similar to pm.Model class * Lots of refactoring, tests for base class, more plain api design * deleted unused module variable * fixed some typos in docstring * Refactored pm.Model class, now it is ready for inheritance * Added documentation for Model class * Small typo in docstring * nested contains for treedict (needed for add_random_variable) * More accurate duplicate implementation of treedict/treelist * refactored treedict/treelist * changed `__imul__` of treelist * added `root` property and `isroot` indicator for base model * protect `parent` and `model` attributes from violation * travis' python2 did not fail on bad syntax(maybe it's too new), fixed * decided not to use functools wrapper Unfortunately functools wrapper fails when decorating built-in methods so I need to fix that improper behaviour. Some bad but needed tricks were implemented * Added models package to setup script * Refactor utils * Fix some typos in pm.model
1 parent 208aa79 commit 40d0146

File tree

9 files changed

+829
-22
lines changed

9 files changed

+829
-22
lines changed

pymc3/model.py

Lines changed: 252 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import threading
2-
2+
import six
33
import numpy as np
44
import theano
55
import theano.tensor as tt
@@ -172,17 +172,215 @@ def logpt(self):
172172
return tt.sum(self.logp_elemwiset)
173173

174174

175-
class Model(Context, Factor):
176-
"""Encapsulates the variables and likelihood factors of a model."""
175+
class InitContextMeta(type):
176+
"""Metaclass that executes `__init__` of instance in it's context"""
177+
def __call__(cls, *args, **kwargs):
178+
instance = cls.__new__(cls, *args, **kwargs)
179+
with instance: # appends context
180+
instance.__init__(*args, **kwargs)
181+
return instance
182+
183+
184+
def withparent(meth):
185+
"""Helper wrapper that passes calls to parent's instance"""
186+
def wrapped(self, *args, **kwargs):
187+
res = meth(self, *args, **kwargs)
188+
if getattr(self, 'parent', None) is not None:
189+
getattr(self.parent, meth.__name__)(*args, **kwargs)
190+
return res
191+
# Unfortunately functools wrapper fails
192+
# when decorating built-in methods so we
193+
# need to fix that improper behaviour
194+
wrapped.__name__ = meth.__name__
195+
return wrapped
196+
197+
198+
class treelist(list):
199+
"""A list that passes mutable extending operations used in Model
200+
to parent list instance.
201+
Extending treelist you will also extend its parent
202+
"""
203+
def __init__(self, iterable=(), parent=None):
204+
super(treelist, self).__init__(iterable)
205+
assert isinstance(parent, list) or parent is None
206+
self.parent = parent
207+
if self.parent is not None:
208+
self.parent.extend(self)
209+
# typechecking here works bad
210+
append = withparent(list.append)
211+
__iadd__ = withparent(list.__iadd__)
212+
extend = withparent(list.extend)
213+
214+
def tree_contains(self, item):
215+
if isinstance(self.parent, treedict):
216+
return (list.__contains__(self, item) or
217+
self.parent.tree_contains(item))
218+
elif isinstance(self.parent, list):
219+
return (list.__contains__(self, item) or
220+
self.parent.__contains__(item))
221+
else:
222+
return list.__contains__(self, item)
223+
224+
def __setitem__(self, key, value):
225+
raise NotImplementedError('Method is removed as we are not'
226+
' able to determine '
227+
'appropriate logic for it')
228+
229+
def __imul__(self, other):
230+
t0 = len(self)
231+
list.__imul__(self, other)
232+
if self.parent is not None:
233+
self.parent.extend(self[t0:])
234+
235+
236+
class treedict(dict):
237+
"""A dict that passes mutable extending operations used in Model
238+
to parent dict instance.
239+
Extending treedict you will also extend its parent
240+
"""
241+
def __init__(self, iterable=(), parent=None, **kwargs):
242+
super(treedict, self).__init__(iterable, **kwargs)
243+
assert isinstance(parent, dict) or parent is None
244+
self.parent = parent
245+
if self.parent is not None:
246+
self.parent.update(self)
247+
# typechecking here works bad
248+
__setitem__ = withparent(dict.__setitem__)
249+
update = withparent(dict.update)
250+
251+
def tree_contains(self, item):
252+
# needed for `add_random_variable` method
253+
if isinstance(self.parent, treedict):
254+
return (dict.__contains__(self, item) or
255+
self.parent.tree_contains(item))
256+
elif isinstance(self.parent, dict):
257+
return (dict.__contains__(self, item) or
258+
self.parent.__contains__(item))
259+
else:
260+
return dict.__contains__(self, item)
261+
262+
263+
class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
264+
"""Encapsulates the variables and likelihood factors of a model.
265+
266+
Model class can be used for creating class based models. To create
267+
a class based model you should inherit from `Model` and
268+
override `__init__` with arbitrary definitions
269+
(do not forget to call base class `__init__` first).
270+
271+
Parameters
272+
----------
273+
name : str, default '' - name that will be used as prefix for
274+
names of all random variables defined within model
275+
model : Model, default None - instance of Model that is
276+
supposed to be a parent for the new instance. If None,
277+
context will be used. All variables defined within instance
278+
will be passed to the parent instance. So that 'nested' model
279+
contributes to the variables and likelihood factors of
280+
parent model.
281+
282+
Examples
283+
--------
284+
# How to define a custom model
285+
class CustomModel(Model):
286+
# 1) override init
287+
def __init__(self, mean=0, sd=1, name='', model=None):
288+
# 2) call super's init first, passing model and name to it
289+
# name will be prefix for all variables here
290+
# if no name specified for model there will be no prefix
291+
super(CustomModel, self).__init__(name, model)
292+
# now you are in the context of instance,
293+
# `modelcontext` will return self
294+
# you can define variables in several ways
295+
# note, that all variables will get model's name prefix
296+
297+
# 3) you can create variables with Var method
298+
self.Var('v1', Normal.dist(mu=mean, sd=sd))
299+
# this will create variable named like '{prefix_}v1'
300+
# and assign attribute 'v1' to instance
301+
# created variable can be accessed with self.v1 or self['v1']
302+
303+
# 4) this syntax will also work as we are in the context
304+
# of instance itself, names are given as usual
305+
Normal('v2', mu=mean, sd=sd)
306+
307+
# something more complex is allowed too
308+
Normal('v3', mu=mean, sd=HalfCauchy('sd', beta=10, testval=1.))
309+
310+
# Deterministic variables can be used in usual way
311+
Deterministic('v3_sq', self.v3 ** 2)
312+
# Potentials too
313+
Potential('p1', tt.constant(1))
314+
315+
# After defining a class CustomModel you can use it in several ways
316+
317+
# I:
318+
# state the model within a context
319+
with Model() as model:
320+
CustomModel()
321+
# arbitrary actions
322+
323+
# II:
324+
# use new class as entering point in context
325+
with CustomModel() as model:
326+
Normal('new_normal_var', mu=1, sd=0)
327+
328+
# III:
329+
# just get model instance with all that was defined in it
330+
model = CustomModel()
331+
332+
# IV:
333+
# use many custom models within one context
334+
with Model() as model:
335+
CustomModel(mean=1, name='first')
336+
CustomModel(mean=2, name='second')
337+
"""
338+
def __new__(cls, *args, **kwargs):
339+
# resolves the parent instance
340+
instance = object.__new__(cls)
341+
if kwargs.get('model') is not None:
342+
instance._parent = kwargs.get('model')
343+
elif cls.get_contexts():
344+
instance._parent = cls.get_contexts()[-1]
345+
else:
346+
instance._parent = None
347+
return instance
348+
349+
def __init__(self, name='', model=None):
350+
self.name = name
351+
if self.parent is not None:
352+
self.named_vars = treedict(parent=self.parent.named_vars)
353+
self.free_RVs = treelist(parent=self.parent.free_RVs)
354+
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
355+
self.deterministics = treelist(parent=self.parent.deterministics)
356+
self.potentials = treelist(parent=self.parent.potentials)
357+
self.missing_values = treelist(parent=self.parent.missing_values)
358+
else:
359+
self.named_vars = treedict()
360+
self.free_RVs = treelist()
361+
self.observed_RVs = treelist()
362+
self.deterministics = treelist()
363+
self.potentials = treelist()
364+
self.missing_values = treelist()
365+
366+
@property
367+
def model(self):
368+
return self
369+
370+
@property
371+
def parent(self):
372+
return self._parent
373+
374+
@property
375+
def root(self):
376+
model = self
377+
while not model.isroot:
378+
model = model.parent
379+
return model
177380

178-
def __init__(self):
179-
self.named_vars = {}
180-
self.free_RVs = []
181-
self.observed_RVs = []
182-
self.deterministics = []
183-
self.potentials = []
184-
self.missing_values = []
185-
self.model = self
381+
@property
382+
def isroot(self):
383+
return self.parent is None
186384

187385
@property
188386
@memoize
@@ -275,6 +473,7 @@ def Var(self, name, dist, data=None):
275473
-------
276474
FreeRV or ObservedRV
277475
"""
476+
name = self.name_for(name)
278477
if data is None:
279478
if getattr(dist, "transform", None) is None:
280479
var = FreeRV(name=name, distribution=dist, model=self)
@@ -312,15 +511,46 @@ def Var(self, name, dist, data=None):
312511

313512
def add_random_variable(self, var):
314513
"""Add a random variable to the named variables of the model."""
315-
if var.name in self.named_vars:
514+
if self.named_vars.tree_contains(var.name):
316515
raise ValueError(
317516
"Variable name {} already exists.".format(var.name))
318517
self.named_vars[var.name] = var
319-
if not hasattr(self, var.name):
320-
setattr(self, var.name, var)
518+
if not hasattr(self, self.name_of(var.name)):
519+
setattr(self, self.name_of(var.name), var)
520+
521+
@property
522+
def prefix(self):
523+
return '%s_' % self.name if self.name else ''
524+
525+
def name_for(self, name):
526+
"""Checks if name has prefix and adds if needed
527+
"""
528+
if self.prefix:
529+
if not name.startswith(self.prefix):
530+
return '{}{}'.format(self.prefix, name)
531+
else:
532+
return name
533+
else:
534+
return name
535+
536+
def name_of(self, name):
537+
"""Checks if name has prefix and deletes if needed
538+
"""
539+
if not self.prefix or not name:
540+
return name
541+
elif name.startswith(self.prefix):
542+
return name[len(self.prefix):]
543+
else:
544+
return name
321545

322546
def __getitem__(self, key):
323-
return self.named_vars[key]
547+
try:
548+
return self.named_vars[key]
549+
except KeyError as e:
550+
try:
551+
return self.named_vars[self.name_for(key)]
552+
except KeyError:
553+
raise e
324554

325555
@memoize
326556
def makefn(self, outs, mode=None, *args, **kwargs):
@@ -637,9 +867,10 @@ def Deterministic(name, var, model=None):
637867
-------
638868
n : var but with name name
639869
"""
640-
var.name = name
641-
modelcontext(model).deterministics.append(var)
642-
modelcontext(model).add_random_variable(var)
870+
model = modelcontext(model)
871+
var.name = model.name_for(name)
872+
model.deterministics.append(var)
873+
model.add_random_variable(var)
643874
return var
644875

645876

@@ -655,8 +886,9 @@ def Potential(name, var, model=None):
655886
-------
656887
var : var, with name attribute
657888
"""
658-
var.name = name
659-
modelcontext(model).potentials.append(var)
889+
model = modelcontext(model)
890+
var.name = model.name_for(name)
891+
model.potentials.append(var)
660892
return var
661893

662894

pymc3/models/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .linear import LinearComponent, Glm
2+
3+
__all__ = [
4+
'LinearComponent',
5+
'Glm'
6+
]

0 commit comments

Comments
 (0)