Skip to content

Commit 98d767a

Browse files
Fix more type issues
1 parent 22cec6b commit 98d767a

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

pymc/aesaraf.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@
5252
from aesara.scalar.basic import Cast
5353
from aesara.tensor.elemwise import Elemwise
5454
from aesara.tensor.random.op import RandomVariable
55+
from aesara.tensor.random.var import (
56+
RandomGeneratorSharedVariable,
57+
RandomStateSharedVariable,
58+
)
5559
from aesara.tensor.shape import SpecifyShape
5660
from aesara.tensor.sharedvar import SharedVariable
5761
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -60,9 +64,7 @@
6064
from pymc.exceptions import ShapeError
6165
from pymc.vartypes import continuous_types, isgenerator, typefilter
6266

63-
PotentialShapeType = Union[
64-
int, np.ndarray, Tuple[Union[int, Variable], ...], List[Union[int, Variable]], Variable
65-
]
67+
PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
6668

6769

6870
__all__ = [
@@ -165,6 +167,7 @@ def change_rv_size(
165167
new_size = (new_size,)
166168

167169
# Extract the RV node that is to be resized, together with its inputs, name and tag
170+
assert rw.owner.op is not None
168171
if isinstance(rv.owner.op, SpecifyShape):
169172
rv = rv.owner.inputs[0]
170173
rv_node = rv.owner
@@ -894,18 +897,14 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
894897
)
895898

896899

897-
def find_rng_nodes(variables: Iterable[TensorVariable]):
900+
def find_rng_nodes(
901+
variables: Iterable[Variable],
902+
) -> List[Union[RandomStateSharedVariable, RandomGeneratorSharedVariable]]:
898903
"""Return RNG variables in a graph"""
899904
return [
900905
node
901906
for node in graph_inputs(variables)
902-
if isinstance(
903-
node,
904-
(
905-
at.random.var.RandomStateSharedVariable,
906-
at.random.var.RandomGeneratorSharedVariable,
907-
),
908-
)
907+
if isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
909908
]
910909

911910

@@ -921,6 +920,7 @@ def reseed_rngs(
921920
np.random.PCG64(sub_seed) for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs))
922921
]
923922
for rng, bit_generator in zip(rngs, bit_generators):
923+
new_rng: Union[np.random.RandomState, np.random.Generator]
924924
if isinstance(rng, at.random.var.RandomStateSharedVariable):
925925
new_rng = np.random.RandomState(bit_generator)
926926
else:
@@ -980,6 +980,9 @@ def compile_pymc(
980980
and isinstance(var.owner.op, (RandomVariable, MeasurableVariable))
981981
and var not in inputs
982982
):
983+
# All nodes in `vars_between(inputs, outputs)` have owners.
984+
# But mypy doesn't know, so we just assert it:
985+
assert random_var.owner.op is not None
983986
if isinstance(random_var.owner.op, RandomVariable):
984987
rng = random_var.owner.inputs[0]
985988
if not hasattr(rng, "default_update"):

0 commit comments

Comments
 (0)