Skip to content

Commit ab5f44f

Browse files
Add Ellipsis-support for the shape kwarg
1 parent bbf8624 commit ab5f44f

File tree

3 files changed

+42
-27
lines changed

3 files changed

+42
-27
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
1212
- The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4696](https://github.com/pymc-devs/pymc3/pull/4696)):
1313
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. Numeric entries in `shape` restrict the model variable to the exact length and re-sizing is no longer possible.
14-
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects. An `Ellipsis` (`...`) in the last position of `dims` can be used as short-hand notation for implied dimensions.
14+
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects.
1515
- The `size` kwarg behaves like it does in Aesara/NumPy. For univariate RVs it is the same as `shape`, but for multivariate RVs it depends on how the RV implements broadcasting to dimensionality greater than `RVOp.ndim_supp`.
16+
- An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions.
1617
- Add `logcdf` method to Kumaraswamy distribution (see [#4706](https://github.com/pymc-devs/pymc3/pull/4706)).
1718
- ...
1819

pymc3/distributions/distribution.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,9 @@ def dist(
378378
The inputs to the `RandomVariable` `Op`.
379379
shape : int, tuple, Variable, optional
380380
A tuple of sizes for each dimension of the new RV.
381+
382+
An Ellipsis (...) may be inserted in the last position to short-hand refer to
383+
all the dimensions that the RV would get if no shape/size/dims were passed at all.
381384
size : int, tuple, Variable, optional
382385
For creating the RV like in Aesara/NumPy.
383386
testval : optional
@@ -404,9 +407,16 @@ def dist(
404407
create_size = None
405408

406409
if shape is not None:
407-
ndim_expected = len(tuple(shape))
408-
ndim_batch = ndim_expected - ndim_supp
409-
create_size = tuple(shape)[:ndim_batch]
410+
if Ellipsis in shape:
411+
# Ellipsis short-hands all implied dimensions. Therefore
412+
# we don't know how many dimensions to expect.
413+
ndim_expected = ndim_batch = None
414+
# Create the RV with its implied shape and resize later
415+
create_size = None
416+
else:
417+
ndim_expected = len(tuple(shape))
418+
ndim_batch = ndim_expected - ndim_supp
419+
create_size = tuple(shape)[:ndim_batch]
410420
elif size is not None:
411421
ndim_expected = ndim_supp + len(tuple(size))
412422
ndim_batch = ndim_expected - ndim_supp
@@ -419,21 +429,25 @@ def dist(
419429
ndims_unexpected = ndim_actual != ndim_expected
420430

421431
if shape is not None and ndims_unexpected:
422-
# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
423-
# Recreate the RV without passing `size` to created it with just the implied dimensions.
424-
rv_out = cls.rv_op(*dist_params, size=None, **kwargs)
425-
426-
# Now resize by the "extra" dimensions that were not implied from support and parameters
427-
if rv_out.ndim < ndim_expected:
428-
expand_shape = shape[: ndim_expected - rv_out.ndim]
429-
rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True)
430-
if not rv_out.ndim == ndim_expected:
431-
raise ShapeError(
432-
f"Failed to create the RV with the expected dimensionality. "
433-
f"This indicates a severe problem. Please open an issue.",
434-
actual=ndim_actual,
435-
expected=ndim_batch + ndim_supp,
436-
)
432+
if Ellipsis in shape:
433+
# Resize and we're done!
434+
rv_out = change_rv_size(rv_var=rv_out, new_size=shape[:-1], expand=True)
435+
else:
436+
# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
437+
# Recreate the RV without passing `size` to created it with just the implied dimensions.
438+
rv_out = cls.rv_op(*dist_params, size=None, **kwargs)
439+
440+
# Now resize by any remaining "extra" dimensions that were not implied from support and parameters
441+
if rv_out.ndim < ndim_expected:
442+
expand_shape = shape[: ndim_expected - rv_out.ndim]
443+
rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True)
444+
if not rv_out.ndim == ndim_expected:
445+
raise ShapeError(
446+
f"Failed to create the RV with the expected dimensionality. "
447+
f"This indicates a severe problem. Please open an issue.",
448+
actual=ndim_actual,
449+
expected=ndim_batch + ndim_supp,
450+
)
437451

438452
# Warn about the edge cases where the RV Op creates more dimensions than
439453
# it should based on `size` and `RVOp.ndim_supp`.

pymc3/tests/test_shape_handling.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class TestShapeDimsSize:
236236
[
237237
"implicit",
238238
"shape",
239-
# "shape...",
239+
"shape...",
240240
"dims",
241241
"dims...",
242242
"size",
@@ -273,9 +273,9 @@ def test_param_and_batch_shape_combos(
273273
if parametrization == "shape":
274274
rv = pm.Normal("rv", mu=mu, shape=batch_shape + param_shape)
275275
assert rv.eval().shape == expected_shape
276-
# elif parametrization == "shape...":
277-
# rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
278-
# assert rv.eval().shape == batch_shape + param_shape
276+
elif parametrization == "shape...":
277+
rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
278+
assert rv.eval().shape == batch_shape + param_shape
279279
elif parametrization == "dims":
280280
rv = pm.Normal("rv", mu=mu, dims=batch_dims + param_dims)
281281
assert rv.eval().shape == expected_shape
@@ -376,7 +376,7 @@ def test_dist_api_works(self):
376376
pm.Normal.dist(mu=mu, dims=("town",))
377377
assert pm.Normal.dist(mu=mu, shape=(3,)).eval().shape == (3,)
378378
assert pm.Normal.dist(mu=mu, shape=(5, 3)).eval().shape == (5, 3)
379-
# assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
379+
assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
380380
assert pm.Normal.dist(mu=mu, size=(3,)).eval().shape == (3,)
381381
assert pm.Normal.dist(mu=mu, size=(4, 3)).eval().shape == (4, 3)
382382

@@ -402,9 +402,9 @@ def test_mvnormal_shape_size_difference(self):
402402
assert rv.ndim == 3
403403
assert tuple(rv.shape.eval()) == (5, 4, 3)
404404

405-
# rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...))
406-
# assert rv.ndim == 5
407-
# assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)
405+
rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...))
406+
assert rv.ndim == 5
407+
assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)
408408

409409
with pytest.warns(None):
410410
rv = pm.MvNormal.dist(mu=[1, 2, 3], cov=np.eye(3), size=(5, 4))

0 commit comments

Comments
 (0)