Skip to content

Commit 68eb273

Browse files
committed
Broadcast dist in Censored with lower and upper when size is not specified
1 parent 3a6163a commit 68eb273

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

pymc/distributions/censored.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from aesara.tensor import TensorVariable
1919
from aesara.tensor.random.op import RandomVariable
2020

21+
from pymc.aesaraf import change_rv_size
2122
from pymc.distributions.distribution import SymbolicDistribution, _moment
2223
from pymc.util import check_dist_not_registered
2324

@@ -74,10 +75,13 @@ def dist(cls, dist, lower, upper, **kwargs):
7475

7576
@classmethod
7677
def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
77-
if lower is None:
78-
lower = at.constant(-np.inf)
79-
if upper is None:
80-
upper = at.constant(np.inf)
78+
79+
lower = at.constant(-np.inf) if lower is None else at.as_tensor_variable(lower)
80+
upper = at.constant(np.inf) if upper is None else at.as_tensor_variable(upper)
81+
82+
# When size is not specified, dist may have to be broadcasted according to lower/upper
83+
dist_shape = size if size is not None else at.broadcast_shape(dist, lower, upper)
84+
dist = change_rv_size(dist, dist_shape)
8185

8286
# Censoring is achieved by clipping the base distribution between lower and upper
8387
rv_out = at.clip(dist, lower, upper)
@@ -88,8 +92,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
8892
rv_out.tag.lower = lower
8993
rv_out.tag.upper = upper
9094

91-
if size is not None:
92-
rv_out = cls.change_size(rv_out, size)
9395
if rngs is not None:
9496
rv_out = cls.change_rngs(rv_out, rngs)
9597

@@ -101,12 +103,10 @@ def ndim_supp(cls, *dist_params):
101103

102104
@classmethod
103105
def change_size(cls, rv, new_size, expand=False):
104-
dist_node = rv.tag.dist.owner
106+
dist = rv.tag.dist
105107
lower = rv.tag.lower
106108
upper = rv.tag.upper
107-
rng, old_size, dtype, *dist_params = dist_node.inputs
108-
new_size = new_size if not expand else tuple(new_size) + tuple(old_size)
109-
new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
109+
new_dist = change_rv_size(dist, new_size, expand=expand)
110110
return cls.rv_op(new_dist, lower, upper)
111111

112112
@classmethod

pymc/tests/test_distributions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3381,6 +3381,18 @@ def test_change_size(self):
33813381
new_dist = pm.Censored.change_size(base_dist, (4,), expand=True)
33823382
assert new_dist.eval().shape == (4, 3, 2)
33833383

3384+
def test_dist_broadcasted_by_lower_upper(self):
3385+
x = pm.Censored.dist(pm.Normal.dist(), lower=np.zeros((2,)), upper=None)
3386+
assert tuple(x.owner.inputs[0].shape.eval()) == (2,)
3387+
3388+
x = pm.Censored.dist(pm.Normal.dist(), lower=np.zeros((2,)), upper=np.zeros((4, 2)))
3389+
assert tuple(x.owner.inputs[0].shape.eval()) == (4, 2)
3390+
3391+
x = pm.Censored.dist(
3392+
pm.Normal.dist(size=(3, 4, 2)), lower=np.zeros((2,)), upper=np.zeros((4, 2))
3393+
)
3394+
assert tuple(x.owner.inputs[0].shape.eval()) == (3, 4, 2)
3395+
33843396

33853397
class TestLKJCholeskCov:
33863398
def test_dist(self):

0 commit comments

Comments
 (0)