Skip to content

Commit b3fb802

Browse files
committed
Ignore explicit size or shape of None
1 parent 5cf7bae commit b3fb802

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

pymc/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def __new__(
281281

282282
# Preference is given to size or shape. If not specified, we rely on dims and
283283
# finally, observed, to determine the shape of the variable.
284-
if not ("size" in kwargs or "shape" in kwargs):
284+
if kwargs.get("size") is None and kwargs.get("shape") is None:
285285
if dims is not None:
286286
kwargs["shape"] = shape_from_dims(dims, model)
287287
elif observed is not None:

pymc/tests/distributions/test_shape_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def test_broadcast_dist_samples_to(self, samples_to_broadcast_to):
218218
broadcast_dist_samples_to(to_shape, samples, size=size)
219219

220220

221-
class TestShapeDimsSize:
221+
class TestSizeShapeDimsObserved:
222222
@pytest.mark.parametrize("param_shape", [(), (2,)])
223223
@pytest.mark.parametrize("batch_shape", [(), (3,)])
224224
@pytest.mark.parametrize(
@@ -465,6 +465,13 @@ def test_size_from_observed_rng_update(self):
465465
# draw, would match the first value of the second draw
466466
assert fn()[1] != fn()[0]
467467

468+
def test_explicit_size_shape_none(self):
469+
with pm.Model() as m:
470+
x = pm.Normal("x", shape=None, observed=[1, 2, 3])
471+
y = pm.Normal("y", size=None, observed=[1, 2, 3, 4])
472+
assert x.shape.eval().item() == 3
473+
assert y.shape.eval().item() == 4
474+
468475

469476
def test_rv_size_is_none():
470477
rv = pm.Normal.dist(0, 1, size=None)

0 commit comments

Comments
 (0)