Skip to content

Commit 2bb0c7c

Browse files
Cleanup trace initialization and type hints (#5420)
* Rename local `trace` variable to `mtrace` To reduce type confusion. * Tighten type hints related to `BaseTrace`/`MultiTrace` * Tighten `tune` to always be `int` The code largely assumed it to be `int`, even though it was often advertised as `int | None`. * Consolidate trace backend initialization With this change, the `_iter_sample` and `_prepare_iter_population` take the `tune` into account for the expected length of traces. This was not the case beforehand and I don't understand why it worked. * Add mypy config and more type hints Also fixes an incorrectly named kwarg in some logger calls.
1 parent 2f8f110 commit 2bb0c7c

File tree

4 files changed

+187
-147
lines changed

4 files changed

+187
-147
lines changed

pymc/distributions/distribution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from aesara.tensor.random.op import RandomVariable
3333
from aesara.tensor.random.var import RandomStateSharedVariable
3434
from aesara.tensor.var import TensorVariable
35+
from typing_extensions import TypeAlias
3536

3637
from pymc.aesaraf import change_rv_size
3738
from pymc.distributions.shape_utils import (
@@ -61,7 +62,7 @@
6162
"NoDistribution",
6263
]
6364

64-
DIST_PARAMETER_TYPES = Union[np.ndarray, int, float, TensorVariable]
65+
DIST_PARAMETER_TYPES: TypeAlias = Union[np.ndarray, int, float, TensorVariable]
6566

6667
vectorized_ppc = contextvars.ContextVar(
6768
"vectorized_ppc", default=None

pymc/distributions/shape_utils.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121
import warnings
2222

2323
from functools import partial
24-
from typing import Optional, Sequence, Tuple, Union
24+
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
2525

2626
import numpy as np
2727

2828
from aesara.graph.basic import Constant, Variable
2929
from aesara.tensor.var import TensorVariable
30+
from typing_extensions import TypeAlias
3031

3132
from pymc.aesaraf import change_rv_size, pandas_to_array
3233
from pymc.exceptions import ShapeError, ShapeWarning
@@ -412,19 +413,31 @@ def broadcast_dist_samples_to(to_shape, samples, size=None):
412413
return [np.broadcast_to(o, to_shape) for o in samples]
413414

414415

416+
# Workaround to annotate the Ellipsis type, posted by the BDFL himself.
417+
# See https://github.com/python/typing/issues/684#issuecomment-548203158
418+
if TYPE_CHECKING:
419+
from enum import Enum
420+
421+
class ellipsis(Enum):
422+
Ellipsis = "..."
423+
424+
Ellipsis = ellipsis.Ellipsis
425+
else:
426+
ellipsis = type(Ellipsis)
427+
415428
# User-provided can be lazily specified as scalars
416-
Shape = Union[int, TensorVariable, Sequence[Union[int, TensorVariable, type(Ellipsis)]]]
417-
Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]]
418-
Size = Union[int, TensorVariable, Sequence[Union[int, TensorVariable]]]
429+
Shape: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, TensorVariable, ellipsis]]]
430+
Dims: TypeAlias = Union[str, Sequence[Optional[Union[str, ellipsis]]]]
431+
Size: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, TensorVariable]]]
419432

420433
# After conversion to vectors
421-
WeakShape = Union[TensorVariable, Tuple[Union[int, TensorVariable, type(Ellipsis)], ...]]
422-
WeakDims = Tuple[Union[str, None, type(Ellipsis)], ...]
434+
WeakShape: TypeAlias = Union[TensorVariable, Tuple[Union[int, TensorVariable, ellipsis], ...]]
435+
WeakDims: TypeAlias = Tuple[Optional[Union[str, ellipsis]], ...]
423436

424437
# After Ellipsis were substituted
425-
StrongShape = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
426-
StrongDims = Sequence[Union[str, None]]
427-
StrongSize = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
438+
StrongShape: TypeAlias = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
439+
StrongDims: TypeAlias = Sequence[Optional[str]]
440+
StrongSize: TypeAlias = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
428441

429442

430443
def convert_dims(dims: Optional[Dims]) -> Optional[WeakDims]:

0 commit comments

Comments
 (0)