Skip to content

Commit bdd3dc0

Browse files
ColCarrolltwiecki
authored andcommitted
Try to fix numpy warnings (#3215)
1 parent 14b3fd0 commit bdd3dc0

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

pymc3/distributions/distribution.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -420,20 +420,16 @@ def _draw_value(param, point=None, givens=None, size=None):
420420

421421
def to_tuple(shape):
422422
"""Convert ints, arrays, and Nones to tuples"""
423-
try:
424-
shape = tuple(shape or ())
425-
except TypeError: # If size is an int
426-
shape = tuple((shape,))
427-
except ValueError: # If size is np.array
428-
shape = tuple(shape)
429-
return shape
423+
if shape is None:
424+
return tuple()
425+
return tuple(np.atleast_1d(shape))
430426

431427
def _is_one_d(dist_shape):
432428
if hasattr(dist_shape, 'dshape') and dist_shape.dshape in ((), (0,), (1,)):
433429
return True
434430
elif hasattr(dist_shape, 'shape') and dist_shape.shape in ((), (0,), (1,)):
435431
return True
436-
elif dist_shape == ():
432+
elif to_tuple(dist_shape) == ():
437433
return True
438434
return False
439435

0 commit comments

Comments
 (0)