Skip to content

Commit e703221

Browse files
committed
Remove ellipsis functionality in shape and dims
1 parent 4fd2916 commit e703221

File tree

3 files changed

+21
-119
lines changed

3 files changed

+21
-119
lines changed

pymc/distributions/distribution.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from abc import ABCMeta
2121
from functools import singledispatch
22-
from typing import Callable, Optional, Sequence, Tuple, Union, cast
22+
from typing import Callable, Optional, Sequence, Tuple, Union
2323

2424
import aesara
2525
import numpy as np
@@ -39,7 +39,6 @@
3939
Shape,
4040
Size,
4141
StrongShape,
42-
WeakDims,
4342
convert_dims,
4443
convert_shape,
4544
convert_size,
@@ -158,7 +157,7 @@ def _make_rv_and_resize_shape(
158157
observed,
159158
args,
160159
**kwargs,
161-
) -> Tuple[Variable, Optional[WeakDims], Optional[Union[np.ndarray, Variable]], StrongShape]:
160+
) -> Tuple[Variable, Optional[Dims], Optional[Union[np.ndarray, Variable]], StrongShape]:
162161
"""Creates the RV and processes dims or observed to determine a resize shape."""
163162
# Create the RV without dims information, because that's not something tracked at the Aesara level.
164163
# If necessary we'll later replicate to a different size implied by already known dims.
@@ -173,9 +172,6 @@ def _make_rv_and_resize_shape(
173172
if dims is not None:
174173
if dims_can_resize:
175174
resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
176-
elif Ellipsis in dims:
177-
# Replace ... with None entries to match the actual dimensionality.
178-
dims = (*dims[:-1], *[None] * ndim_actual)[:ndim_actual]
179175
elif observed is not None:
180176
resize_shape, observed = resize_from_observed(observed, ndim_actual)
181177
return rv_out, dims, observed, resize_shape
@@ -305,9 +301,6 @@ def dist(
305301
The inputs to the `RandomVariable` `Op`.
306302
shape : int, tuple, Variable, optional
307303
A tuple of sizes for each dimension of the new RV.
308-
309-
An Ellipsis (...) may be inserted in the last position to short-hand refer to
310-
all the dimensions that the RV would get if no shape/size/dims were passed at all.
311304
**kwargs
312305
Keyword arguments that will be forwarded to the Aesara RV Op.
313306
Most prominently: ``size`` or ``dtype``.
@@ -350,11 +343,6 @@ def dist(
350343
# This is not necessarily the final result.
351344
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
352345

353-
# Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
354-
if shape is not None and Ellipsis in shape:
355-
replicate_shape = cast(StrongShape, shape[:-1])
356-
rv_out = change_rv_size(rv=rv_out, new_size=replicate_shape, expand=True)
357-
358346
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
359347
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
360348
rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
@@ -508,8 +496,6 @@ def dist(
508496
The inputs to the `RandomVariable` `Op`.
509497
shape : int, tuple, Variable, optional
510498
A tuple of sizes for each dimension of the new RV.
511-
An Ellipsis (...) may be inserted in the last position to short-hand refer to
512-
all the dimensions that the RV would get if no shape/size/dims were passed at all.
513499
size : int, tuple, Variable, optional
514500
For creating the RV like in Aesara/NumPy.
515501
@@ -550,11 +536,6 @@ def dist(
550536
# This is not necessarily the final result.
551537
graph = cls.rv_op(*dist_params, size=create_size, **kwargs)
552538

553-
# Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
554-
if shape is not None and Ellipsis in shape:
555-
replicate_shape = cast(StrongShape, shape[:-1])
556-
graph = cls.change_size(rv=graph, new_size=replicate_shape, expand=True)
557-
558539
# TODO: Create new attr error stating that these are not available for DerivedDistribution
559540
# rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
560541
# rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")

pymc/distributions/shape_utils.py

Lines changed: 11 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
samples from probability distributions for stochastic nodes in PyMC.
1919
"""
2020

21-
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union, cast
21+
from typing import Optional, Sequence, Tuple, Union, cast
2222

2323
import numpy as np
2424

@@ -409,34 +409,18 @@ def broadcast_dist_samples_to(to_shape, samples, size=None):
409409
return [np.broadcast_to(o, to_shape) for o in samples]
410410

411411

412-
# Workaround to annotate the Ellipsis type, posted by the BDFL himself.
413-
# See https://github.com/python/typing/issues/684#issuecomment-548203158
414-
if TYPE_CHECKING:
415-
from enum import Enum
416-
417-
class ellipsis(Enum):
418-
Ellipsis = "..."
419-
420-
Ellipsis = ellipsis.Ellipsis
421-
else:
422-
ellipsis = type(Ellipsis)
423-
424412
# User-provided values can be lazily specified as scalars
425-
Shape: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, Variable, ellipsis]]]
426-
Dims: TypeAlias = Union[str, Sequence[Optional[Union[str, ellipsis]]]]
413+
Shape: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, Variable]]]
414+
Dims: TypeAlias = Union[str, Sequence[Optional[str]]]
427415
Size: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, Variable]]]
428416

429417
# After conversion to vectors
430-
WeakShape: TypeAlias = Union[TensorVariable, Tuple[Union[int, Variable, ellipsis], ...]]
431-
WeakDims: TypeAlias = Tuple[Optional[Union[str, ellipsis]], ...]
432-
433-
# After Ellipsis were substituted
434418
StrongShape: TypeAlias = Union[TensorVariable, Tuple[Union[int, Variable], ...]]
435419
StrongDims: TypeAlias = Sequence[Optional[str]]
436420
StrongSize: TypeAlias = Union[TensorVariable, Tuple[Union[int, Variable], ...]]
437421

438422

439-
def convert_dims(dims: Optional[Dims]) -> Optional[WeakDims]:
423+
def convert_dims(dims: Optional[Dims]) -> Optional[StrongDims]:
440424
"""Process a user-provided dims variable into None or a valid dims tuple."""
441425
if dims is None:
442426
return None
@@ -448,13 +432,10 @@ def convert_dims(dims: Optional[Dims]) -> Optional[WeakDims]:
448432
else:
449433
raise ValueError(f"The `dims` parameter must be a tuple, str or list. Actual: {type(dims)}")
450434

451-
if any(d == Ellipsis for d in dims[:-1]):
452-
raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}")
453-
454435
return dims
455436

456437

457-
def convert_shape(shape: Shape) -> Optional[WeakShape]:
438+
def convert_shape(shape: Shape) -> Optional[StrongShape]:
458439
"""Process a user-provided shape variable into None or a valid shape object."""
459440
if shape is None:
460441
return None
@@ -468,10 +449,6 @@ def convert_shape(shape: Shape) -> Optional[WeakShape]:
468449
raise ValueError(
469450
f"The `shape` parameter must be a tuple, TensorVariable, int or list. Actual: {type(shape)}"
470451
)
471-
if isinstance(shape, tuple) and any(s == Ellipsis for s in shape[:-1]):
472-
raise ValueError(
473-
f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
474-
)
475452

476453
return shape
477454

@@ -490,19 +467,17 @@ def convert_size(size: Size) -> Optional[StrongSize]:
490467
raise ValueError(
491468
f"The `size` parameter must be a tuple, TensorVariable, int or list. Actual: {type(size)}"
492469
)
493-
if isinstance(size, tuple) and Ellipsis in size:
494-
raise ValueError(f"The `size` parameter cannot contain an Ellipsis. Actual: {size}")
495470

496471
return size
497472

498473

499-
def resize_from_dims(dims: WeakDims, ndim_implied: int, model) -> Tuple[StrongSize, StrongDims]:
474+
def resize_from_dims(dims: Dims, ndim_implied: int, model) -> Tuple[StrongSize, StrongDims]:
500475
"""Determines a potential resize shape from a `dims` tuple.
501476
502477
Parameters
503478
----------
504479
dims : array-like
505-
A vector of dimension names, None or Ellipsis.
480+
A vector of dimension names or None.
506481
ndim_implied : int
507482
Number of RV dimensions that were implied from its inputs alone.
508483
model : pm.Model
@@ -515,11 +490,6 @@ def resize_from_dims(dims: WeakDims, ndim_implied: int, model) -> Tuple[StrongSi
515490
dims : tuple of (str or None)
516491
Names or None for all dimensions after resizing.
517492
"""
518-
if Ellipsis in dims:
519-
# Auto-complete the dims tuple to the full length.
520-
# We don't have a way to know the names of implied
521-
# dimensions, so they will be `None`.
522-
dims = (*dims[:-1], *[None] * ndim_implied)
523493
sdims = cast(StrongDims, dims)
524494

525495
ndim_resize = len(sdims) - ndim_implied
@@ -565,7 +535,7 @@ def resize_from_observed(
565535

566536

567537
def find_size(
568-
shape: Optional[WeakShape],
538+
shape: Optional[StrongShape],
569539
size: Optional[StrongSize],
570540
ndim_supp: int,
571541
) -> Tuple[Optional[StrongSize], Optional[int], Optional[int], int]:
@@ -598,16 +568,9 @@ def find_size(
598568
create_size: Optional[StrongSize] = None
599569

600570
if shape is not None:
601-
if Ellipsis in shape:
602-
# Ellipsis short-hands all implied dimensions. Therefore
603-
# we don't know how many dimensions to expect.
604-
ndim_expected = ndim_batch = None
605-
# Create the RV with its implied shape and resize later
606-
create_size = None
607-
else:
608-
ndim_expected = len(tuple(shape))
609-
ndim_batch = ndim_expected - ndim_supp
610-
create_size = tuple(shape)[:ndim_batch]
571+
ndim_expected = len(tuple(shape))
572+
ndim_batch = ndim_expected - ndim_supp
573+
create_size = tuple(shape)[:ndim_batch]
611574
elif size is not None:
612575
ndim_expected = ndim_supp + len(tuple(size))
613576
ndim_batch = ndim_expected - ndim_supp

pymc/tests/test_shape_handling.py

Lines changed: 8 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,7 @@ class TestShapeDimsSize:
211211
[
212212
"implicit",
213213
"shape",
214-
"shape...",
215214
"dims",
216-
"dims...",
217215
"size",
218216
],
219217
)
@@ -249,65 +247,36 @@ def test_param_and_batch_shape_combos(
249247
if parametrization == "shape":
250248
rv = pm.Normal("rv", mu=mu, shape=batch_shape + param_shape)
251249
assert rv.eval().shape == expected_shape
252-
elif parametrization == "shape...":
253-
rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
254-
assert rv.eval().shape == batch_shape + param_shape
255250
elif parametrization == "dims":
256251
rv = pm.Normal("rv", mu=mu, dims=batch_dims + param_dims)
257252
assert rv.eval().shape == expected_shape
258-
elif parametrization == "dims...":
259-
rv = pm.Normal("rv", mu=mu, dims=(*batch_dims, ...))
260-
n_size = len(batch_shape)
261-
n_implied = len(param_shape)
262-
ndim = n_size + n_implied
263-
assert len(pmodel.RV_dims["rv"]) == ndim, pmodel.RV_dims
264-
assert len(pmodel.RV_dims["rv"][:n_size]) == len(batch_dims)
265-
assert len(pmodel.RV_dims["rv"][n_size:]) == len(param_dims)
266-
if n_implied > 0:
267-
assert pmodel.RV_dims["rv"][-1] is None
268253
elif parametrization == "size":
269254
rv = pm.Normal("rv", mu=mu, size=batch_shape + param_shape)
270255
assert rv.eval().shape == expected_shape
271256
else:
272257
raise NotImplementedError("Invalid test case parametrization.")
273258

274-
@pytest.mark.parametrize("ellipsis_in", ["none", "shape", "dims", "both"])
275-
def test_simultaneous_shape_and_dims(self, ellipsis_in):
259+
def test_simultaneous_shape_and_dims(self):
276260
with pm.Model() as pmodel:
277261
x = pm.ConstantData("x", [1, 2, 3], dims="ddata")
278262

279-
if ellipsis_in == "none":
280-
# The shape and dims tuples correspond to each other.
281-
# Note: No checks are performed that implied shape (x), shape and dims actually match.
282-
y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", "ddata"))
283-
assert pmodel.RV_dims["y"] == ("dshape", "ddata")
284-
elif ellipsis_in == "shape":
285-
y = pm.Normal("y", mu=x, shape=(2, ...), dims=("dshape", "ddata"))
286-
assert pmodel.RV_dims["y"] == ("dshape", "ddata")
287-
elif ellipsis_in == "dims":
288-
y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", ...))
289-
assert pmodel.RV_dims["y"] == ("dshape", None)
290-
elif ellipsis_in == "both":
291-
y = pm.Normal("y", mu=x, shape=(2, ...), dims=("dshape", ...))
292-
assert pmodel.RV_dims["y"] == ("dshape", None)
263+
# The shape and dims tuples correspond to each other.
264+
# Note: No checks are performed that implied shape (x), shape and dims actually match.
265+
y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", "ddata"))
266+
assert pmodel.RV_dims["y"] == ("dshape", "ddata")
293267

294268
assert "dshape" in pmodel.dim_lengths
295269
assert y.eval().shape == (2, 3)
296270

297-
@pytest.mark.parametrize("with_dims_ellipsis", [False, True])
298-
def test_simultaneous_size_and_dims(self, with_dims_ellipsis):
271+
def test_simultaneous_size_and_dims(self):
299272
with pm.Model() as pmodel:
300273
x = pm.ConstantData("x", [1, 2, 3], dims="ddata")
301274
assert "ddata" in pmodel.dim_lengths
302275

303276
# Size does not include support dims, so this test must use a dist with support dims.
304277
kwargs = dict(name="y", size=(2, 3), mu=at.ones((3, 4)), cov=at.eye(4))
305-
if with_dims_ellipsis:
306-
y = pm.MvNormal(**kwargs, dims=("dsize", ...))
307-
assert pmodel.RV_dims["y"] == ("dsize", None, None)
308-
else:
309-
y = pm.MvNormal(**kwargs, dims=("dsize", "ddata", "dsupport"))
310-
assert pmodel.RV_dims["y"] == ("dsize", "ddata", "dsupport")
278+
y = pm.MvNormal(**kwargs, dims=("dsize", "ddata", "dsupport"))
279+
assert pmodel.RV_dims["y"] == ("dsize", "ddata", "dsupport")
311280

312281
assert "dsize" in pmodel.dim_lengths
313282
assert y.eval().shape == (2, 3, 4)
@@ -382,7 +351,6 @@ def test_dist_api_works(self):
382351
pm.Normal.dist(mu=mu, dims=("town",))
383352
assert pm.Normal.dist(mu=mu, shape=(3,)).eval().shape == (3,)
384353
assert pm.Normal.dist(mu=mu, shape=(5, 3)).eval().shape == (5, 3)
385-
assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
386354
assert pm.Normal.dist(mu=mu, size=(3,)).eval().shape == (3,)
387355
assert pm.Normal.dist(mu=mu, size=(4, 3)).eval().shape == (4, 3)
388356

@@ -408,10 +376,6 @@ def test_mvnormal_shape_size_difference(self):
408376
assert rv.ndim == 3
409377
assert tuple(rv.shape.eval()) == (5, 4, 3)
410378

411-
rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...))
412-
assert rv.ndim == 5
413-
assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)
414-
415379
rv = pm.MvNormal.dist(mu=[1, 2, 3], cov=np.eye(3), size=(5, 4))
416380
assert tuple(rv.shape.eval()) == (5, 4, 3)
417381

@@ -422,22 +386,16 @@ def test_convert_dims(self):
422386
assert convert_dims(dims="town") == ("town",)
423387
with pytest.raises(ValueError, match="must be a tuple, str or list"):
424388
convert_dims(3)
425-
with pytest.raises(ValueError, match="may only appear in the last position"):
426-
convert_dims(dims=(..., "town"))
427389

428390
def test_convert_shape(self):
429391
assert convert_shape(5) == (5,)
430392
with pytest.raises(ValueError, match="tuple, TensorVariable, int or list"):
431393
convert_shape(shape="notashape")
432-
with pytest.raises(ValueError, match="may only appear in the last position"):
433-
convert_shape(shape=(3, ..., 2))
434394

435395
def test_convert_size(self):
436396
assert convert_size(7) == (7,)
437397
with pytest.raises(ValueError, match="tuple, TensorVariable, int or list"):
438398
convert_size(size="notasize")
439-
with pytest.raises(ValueError, match="cannot contain"):
440-
convert_size(size=(3, ...))
441399

442400
def test_lazy_flavors(self):
443401
assert pm.Uniform.dist(2, [4, 5], size=[3, 2]).eval().shape == (3, 2)

0 commit comments

Comments
 (0)