Skip to content

Commit e6e1b4a

Browse files
forced derived shape components to be tuples
1 parent 9563d8a commit e6e1b4a

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

pymc3/distributions/multivariate.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ def __init__(self, mu, tau, ndim=None, size=None, dtype=None, *args,
102102

103103
self.median = self.mode = self.mean = self.mu
104104
# TODO: How do we want to use ndim?
105-
shape_supp = self.mu.shape[-1]
106-
shape_ind = self.mu.shape[:-1]
105+
shape_supp = (self.mu.shape[-1],)
106+
shape_ind = tuple(self.mu.shape[:-1])
107107

108108
if self.mu.ndim > 0:
109109
bcast = (False,) * (1 + tt.get_vector_length(shape_ind))
@@ -261,8 +261,8 @@ def __init__(self, a, transform=transforms.stick_breaking, ndim=None, size=None,
261261
self.dist_params = (self.a,)
262262

263263
# TODO: How do we want to use ndim?
264-
shape_supp = self.a.shape[-1]
265-
shape_ind = self.a.shape[:-1]
264+
shape_supp = (self.a.shape[-1],)
265+
shape_ind = tuple(self.a.shape[:-1])
266266

267267
# FIXME: this isn't correct/ideal
268268
if self.a.ndim > 0:
@@ -358,7 +358,7 @@ def __init__(self, n, p, ndim=None, size=None, dtype=None, *args,
358358

359359
# TODO: check that n == len(p)?
360360
# TODO: How do we want to use ndim?
361-
shape_supp = self.n
361+
shape_supp = (self.n,)
362362
shape_ind = ()
363363

364364
# FIXME: this isn't correct/ideal
@@ -500,12 +500,11 @@ def __init__(self, n, V, ndim=None, size=None, dtype=None, *args,
500500
self.dist_params = (self.n, self.V)
501501

502502
# TODO: How do we want to use ndim?
503-
shape_supp = self.V.shape[-1]
503+
shape_supp = (self.V.shape[-1],)
504504
shape_ind = ()
505505

506-
self.mode = tt.switch(1 * (self.n >= shape_supp + 1), (self.n -
507-
shape_supp - 1)
508-
* self.V, np.nan)
506+
self.mode = tt.switch(1 * (self.n >= shape_supp + 1),
507+
(self.n - shape_supp - 1) * self.V, np.nan)
509508
self.mean = self.n * self.V
510509

511510
# FIXME: this isn't correct/ideal
@@ -662,19 +661,18 @@ def __init__(self, n, p, ndim=None, size=None, dtype=None, *args, **kwargs):
662661
self.n = tt.as_tensor_variable(n, ndim=0)
663662
self.p = tt.as_tensor_variable(p, ndim=0)
664663

664+
self.dist_params = (self.n, self.p)
665+
665666
# TODO: How do we want to use ndim?
666667
n_elem = (self.p * (self.p - 1) / 2)
667-
shape_supp = n_elem
668+
shape_supp = (n_elem,)
668669
self.mean = tt.zeros(n_elem)
669670

670671
# FIXME: triu, bcast, etc.
671672
self.tri_index = tt.zeros((self.p, self.p), dtype=int)
672673
self.tri_index[tt.triu(self.p, k=1)] = tt.arange(n_elem)
673674
self.tri_index[tt.triu(self.p, k=1)[::-1]] = tt.arange(n_elem)
674675

675-
self.dist_params = (self.n, self.p)
676-
677-
# TODO: do this correctly; what about replications?
678676
shape_ind = ()
679677

680678
# FIXME: this isn't correct/ideal

0 commit comments

Comments
 (0)