|
21 | 21 | import warnings
|
22 | 22 |
|
23 | 23 | from functools import partial
|
24 |
| -from typing import Optional, Sequence, Tuple, Union |
| 24 | +from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union |
25 | 25 |
|
26 | 26 | import numpy as np
|
27 | 27 |
|
28 | 28 | from aesara.graph.basic import Constant, Variable
|
29 | 29 | from aesara.tensor.var import TensorVariable
|
| 30 | +from typing_extensions import TypeAlias |
30 | 31 |
|
31 | 32 | from pymc.aesaraf import change_rv_size, pandas_to_array
|
32 | 33 | from pymc.exceptions import ShapeError, ShapeWarning
|
@@ -412,19 +413,31 @@ def broadcast_dist_samples_to(to_shape, samples, size=None):
|
412 | 413 | return [np.broadcast_to(o, to_shape) for o in samples]
|
413 | 414 |
|
414 | 415 |
|
| 416 | +# Workaround to annotate the Ellipsis type, posted by the BDFL himself. |
| 417 | +# See https://github.com/python/typing/issues/684#issuecomment-548203158 |
| 418 | +if TYPE_CHECKING: |
| 419 | + from enum import Enum |
| 420 | + |
| 421 | + class ellipsis(Enum): |
| 422 | + Ellipsis = "..." |
| 423 | + |
| 424 | + Ellipsis = ellipsis.Ellipsis |
| 425 | +else: |
| 426 | + ellipsis = type(Ellipsis) |
| 427 | + |
415 | 428 | # User-provided can be lazily specified as scalars
|
416 |
| -Shape = Union[int, TensorVariable, Sequence[Union[int, TensorVariable, type(Ellipsis)]]] |
417 |
| -Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]] |
418 |
| -Size = Union[int, TensorVariable, Sequence[Union[int, TensorVariable]]] |
| 429 | +Shape: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, TensorVariable, ellipsis]]] |
| 430 | +Dims: TypeAlias = Union[str, Sequence[Optional[Union[str, ellipsis]]]] |
| 431 | +Size: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, TensorVariable]]] |
419 | 432 |
|
420 | 433 | # After conversion to vectors
|
421 |
| -WeakShape = Union[TensorVariable, Tuple[Union[int, TensorVariable, type(Ellipsis)], ...]] |
422 |
| -WeakDims = Tuple[Union[str, None, type(Ellipsis)], ...] |
| 434 | +WeakShape: TypeAlias = Union[TensorVariable, Tuple[Union[int, TensorVariable, ellipsis], ...]] |
| 435 | +WeakDims: TypeAlias = Tuple[Optional[Union[str, ellipsis]], ...] |
423 | 436 |
|
424 | 437 | # After Ellipsis were substituted
|
425 |
| -StrongShape = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]] |
426 |
| -StrongDims = Sequence[Union[str, None]] |
427 |
| -StrongSize = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]] |
| 438 | +StrongShape: TypeAlias = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]] |
| 439 | +StrongDims: TypeAlias = Sequence[Optional[str]] |
| 440 | +StrongSize: TypeAlias = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]] |
428 | 441 |
|
429 | 442 |
|
430 | 443 | def convert_dims(dims: Optional[Dims]) -> Optional[WeakDims]:
|
|
0 commit comments