Skip to content

Commit 9f99178

Browse files
fixed (some) handling of symbolic scalar shapes
1 parent 69c6ffe commit 9f99178

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

pymc3/distributions/distribution.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,32 @@ def _as_tensor_shape_variable(var):
4242
return res
4343

4444

45+
def _as_tensor_shape_variable(var):
46+
""" Just a collection of useful shape stuff from
47+
`_infer_ndim_bcast` """
48+
49+
if var is None:
50+
return T.constant([], dtype='int64')
51+
52+
res = var
53+
if isinstance(res, (tuple, list)):
54+
if len(res) == 0:
55+
return T.constant([], dtype='int64')
56+
res = T.as_tensor_variable(res, ndim=1)
57+
58+
else:
59+
if res.ndim != 1:
60+
raise TypeError("shape must be a vector or list of scalar, got\
61+
'%s'" % res)
62+
63+
if (not (res.dtype.startswith('int') or
64+
res.dtype.startswith('uint'))):
65+
66+
raise TypeError('shape must be an integer vector or list',
67+
res.dtype)
68+
return res
69+
70+
4571
class Distribution(object):
4672
"""Statistical distribution"""
4773
def __new__(cls, name, *args, **kwargs):

0 commit comments

Comments
 (0)