Skip to content

Commit 88300b5

Browse files
committed
Grouped types into theano_constant tuple
1 parent 98635eb commit 88300b5

File tree

2 files changed

+32
-25
lines changed

2 files changed

+32
-25
lines changed

pymc3/distributions/distribution.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
Model, get_named_nodes_and_relations, FreeRV,
1010
ObservedRV, MultiObservedRV, Context, InitContextMeta
1111
)
12-
from ..vartypes import string_types
12+
from ..vartypes import string_types, theano_constant
1313
from .shape_utils import (
1414
to_tuple,
1515
get_broadcastable_dist_samples,
@@ -92,7 +92,7 @@ def getattr_value(self, val):
9292
if isinstance(val, tt.TensorVariable):
9393
return val.tag.test_value
9494

95-
if isinstance(val, (tt.TensorConstant, theano.gof.graph.Constant)):
95+
if isinstance(val, theano_constant):
9696
return val.value
9797

9898
return val
@@ -502,8 +502,7 @@ def __init__(self):
502502
def is_fast_drawable(var):
503503
return isinstance(var, (numbers.Number,
504504
np.ndarray,
505-
tt.TensorConstant,
506-
theano.gof.graph.Constant,
505+
theano_constant,
507506
tt.sharedvar.SharedVariable))
508507

509508

@@ -593,8 +592,7 @@ def draw_values(params, point=None, size=None):
593592
if (next_, size) in drawn:
594593
# If the node already has a givens value, skip it
595594
continue
596-
elif isinstance(next_, (tt.TensorConstant,
597-
theano.gof.graph.Constant,
595+
elif isinstance(next_, (theano_constant,
598596
tt.sharedvar.SharedVariable)):
599597
# If the node is a theano.tensor.TensorConstant or a
600598
# theano.tensor.sharedvar.SharedVariable, its value will be
@@ -785,7 +783,7 @@ def _draw_value(param, point=None, givens=None, size=None):
785783
"""
786784
if isinstance(param, (numbers.Number, np.ndarray)):
787785
return param
788-
elif isinstance(param, (tt.TensorConstant, theano.gof.graph.Constant)):
786+
elif isinstance(param, theano_constant):
789787
return param.value
790788
elif isinstance(param, tt.sharedvar.SharedVariable):
791789
return param.get_value()

pymc3/vartypes.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
1-
__all__ = ['bool_types', 'int_types', 'float_types', 'complex_types', 'continuous_types',
2-
'discrete_types', 'typefilter', 'isgenerator']
3-
4-
bool_types = set(['int8'])
5-
6-
int_types = set(['int8',
7-
'int16',
8-
'int32',
9-
'int64',
10-
'uint8',
11-
'uint16',
12-
'uint32',
13-
'uint64'])
14-
float_types = set(['float32',
15-
'float64'])
16-
complex_types = set(['complex64',
17-
'complex128'])
1+
from theano.tensor import Constant as tensor_constant
2+
from theano.gof.graph import Constant as graph_constant
3+
4+
5+
__all__ = [
6+
"bool_types",
7+
"int_types",
8+
"float_types",
9+
"complex_types",
10+
"continuous_types",
11+
"discrete_types",
12+
"typefilter",
13+
"isgenerator",
14+
"theano_constant",
15+
]
16+
17+
bool_types = set(["int8"])
18+
19+
int_types = set(
20+
["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
21+
)
22+
float_types = set(["float32", "float64"])
23+
complex_types = set(["complex64", "complex128"])
1824
continuous_types = float_types | complex_types
1925
discrete_types = bool_types | int_types
2026

@@ -27,4 +33,7 @@ def typefilter(vars, types):
2733

2834

2935
def isgenerator(obj):
30-
return hasattr(obj, '__next__')
36+
return hasattr(obj, "__next__")
37+
38+
39+
theano_constant = (tensor_constant, graph_constant)

0 commit comments

Comments
 (0)