Skip to content

Commit a97206d

Browse files
authored
Discrete dists need int16 when run in float32 mode. (#2116)
* Discrete dists need int16 when run in float32 mode. * typo * int32 -> int64
1 parent 9af768c commit a97206d

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

pymc3/distributions/distribution.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,15 @@ def logp(self, x):
111111
class Discrete(Distribution):
112112
"""Base class for discrete distributions"""
113113

114-
def __init__(self, shape=(), dtype='int64', defaults=['mode'], *args, **kwargs):
115-
if dtype != 'int64':
116-
raise TypeError('Discrete classes expect dtype to be int64.')
114+
def __init__(self, shape=(), dtype=None, defaults=['mode'],
115+
*args, **kwargs):
116+
if dtype is None:
117+
if theano.config.floatX == 'float32':
118+
dtype = 'int16'
119+
else:
120+
dtype = 'int64'
121+
if dtype != 'int16' and dtype != 'int64':
122+
raise TypeError('Discrete classes expect dtype to be int16 or int64.')
117123
super(Discrete, self).__init__(
118124
shape, dtype, defaults=defaults, *args, **kwargs)
119125

0 commit comments

Comments
 (0)