52
52
from aesara .scalar .basic import Cast
53
53
from aesara .tensor .elemwise import Elemwise
54
54
from aesara .tensor .random .op import RandomVariable
55
+ from aesara .tensor .random .var import (
56
+ RandomGeneratorSharedVariable ,
57
+ RandomStateSharedVariable ,
58
+ )
55
59
from aesara .tensor .shape import SpecifyShape
56
60
from aesara .tensor .sharedvar import SharedVariable
57
61
from aesara .tensor .subtensor import AdvancedIncSubtensor , AdvancedIncSubtensor1
60
64
from pymc .exceptions import ShapeError
61
65
from pymc .vartypes import continuous_types , isgenerator , typefilter
62
66
63
- PotentialShapeType = Union [
64
- int , np .ndarray , Tuple [Union [int , Variable ], ...], List [Union [int , Variable ]], Variable
65
- ]
67
+ PotentialShapeType = Union [int , np .ndarray , Sequence [Union [int , Variable ]], TensorVariable ]
66
68
67
69
68
70
__all__ = [
@@ -165,6 +167,7 @@ def change_rv_size(
165
167
new_size = (new_size ,)
166
168
167
169
# Extract the RV node that is to be resized, together with its inputs, name and tag
170
+ assert rw .owner .op is not None
168
171
if isinstance (rv .owner .op , SpecifyShape ):
169
172
rv = rv .owner .inputs [0 ]
170
173
rv_node = rv .owner
@@ -894,18 +897,14 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
894
897
)
895
898
896
899
897
- def find_rng_nodes (variables : Iterable [TensorVariable ]):
900
+ def find_rng_nodes (
901
+ variables : Iterable [Variable ],
902
+ ) -> List [Union [RandomStateSharedVariable , RandomGeneratorSharedVariable ]]:
898
903
"""Return RNG variables in a graph"""
899
904
return [
900
905
node
901
906
for node in graph_inputs (variables )
902
- if isinstance (
903
- node ,
904
- (
905
- at .random .var .RandomStateSharedVariable ,
906
- at .random .var .RandomGeneratorSharedVariable ,
907
- ),
908
- )
907
+ if isinstance (node , (RandomStateSharedVariable , RandomGeneratorSharedVariable ))
909
908
]
910
909
911
910
@@ -921,6 +920,7 @@ def reseed_rngs(
921
920
np .random .PCG64 (sub_seed ) for sub_seed in np .random .SeedSequence (seed ).spawn (len (rngs ))
922
921
]
923
922
for rng , bit_generator in zip (rngs , bit_generators ):
923
+ new_rng : Union [np .random .RandomState , np .random .Generator ]
924
924
if isinstance (rng , at .random .var .RandomStateSharedVariable ):
925
925
new_rng = np .random .RandomState (bit_generator )
926
926
else :
@@ -980,6 +980,9 @@ def compile_pymc(
980
980
and isinstance (var .owner .op , (RandomVariable , MeasurableVariable ))
981
981
and var not in inputs
982
982
):
983
+ # All nodes in `vars_between(inputs, outputs)` have owners.
984
+ # But mypy doesn't know, so we just assert it:
985
+ assert random_var .owner .op is not None
983
986
if isinstance (random_var .owner .op , RandomVariable ):
984
987
rng = random_var .owner .inputs [0 ]
985
988
if not hasattr (rng , "default_update" ):
0 commit comments