Skip to content

Commit 8fda70c

Browse files
committed
1 parent 8223a02 commit 8fda70c

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

pymc3/distributions/discrete.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,12 @@ def logp(self, value):
998998

999999
if p.ndim > 1:
10001000
pattern = (p.ndim - 1,) + tuple(range(p.ndim - 1))
1001-
a = tt.log(p.dimshuffle(pattern)[value_clip])
1001+
a = tt.log(
1002+
tt.choose(
1003+
value_clip,
1004+
p.dimshuffle(pattern),
1005+
)
1006+
)
10021007
else:
10031008
a = tt.log(p[value_clip])
10041009

@@ -1570,13 +1575,13 @@ def __init__(self, eta, cutpoints, *args, **kwargs):
15701575
self.eta = tt.as_tensor_variable(floatX(eta))
15711576
self.cutpoints = tt.as_tensor_variable(cutpoints)
15721577

1573-
pa = sigmoid(tt.shape_padleft(self.cutpoints) - tt.shape_padright(self.eta))
1578+
pa = sigmoid(self.cutpoints - tt.shape_padright(self.eta))
15741579
p_cum = tt.concatenate([
1575-
tt.zeros_like(tt.shape_padright(pa[:, 0])),
1580+
tt.zeros_like(tt.shape_padright(pa[..., 0])),
15761581
pa,
1577-
tt.ones_like(tt.shape_padright(pa[:, 0]))
1582+
tt.ones_like(tt.shape_padright(pa[..., 0]))
15781583
], axis=-1)
1579-
p = p_cum[:, 1:] - p_cum[:, :-1]
1584+
p = p_cum[..., 1:] - p_cum[..., :-1]
15801585

15811586
super().__init__(p=p, *args, **kwargs)
15821587

0 commit comments

Comments
 (0)