Skip to content

Commit 6b22ed5

Browse files
Add return type hints (#5819)
* Add return type hints * Add type hints for rv arguments * Fix more type issues Closes #4880 Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 57654dc commit 6b22ed5

File tree

3 files changed

+29
-24
lines changed

3 files changed

+29
-24
lines changed

pymc/aesaraf.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@
5252
from aesara.scalar.basic import Cast
5353
from aesara.tensor.elemwise import Elemwise
5454
from aesara.tensor.random.op import RandomVariable
55+
from aesara.tensor.random.var import (
56+
RandomGeneratorSharedVariable,
57+
RandomStateSharedVariable,
58+
)
5559
from aesara.tensor.shape import SpecifyShape
5660
from aesara.tensor.sharedvar import SharedVariable
5761
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -60,9 +64,7 @@
6064
from pymc.exceptions import ShapeError
6165
from pymc.vartypes import continuous_types, isgenerator, typefilter
6266

63-
PotentialShapeType = Union[
64-
int, np.ndarray, Tuple[Union[int, Variable], ...], List[Union[int, Variable]], Variable
65-
]
67+
PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
6668

6769

6870
__all__ = [
@@ -165,6 +167,7 @@ def change_rv_size(
165167
new_size = (new_size,)
166168

167169
# Extract the RV node that is to be resized, together with its inputs, name and tag
170+
assert rv.owner.op is not None
168171
if isinstance(rv.owner.op, SpecifyShape):
169172
rv = rv.owner.inputs[0]
170173
rv_node = rv.owner
@@ -894,18 +897,14 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
894897
)
895898

896899

897-
def find_rng_nodes(variables: Iterable[TensorVariable]):
900+
def find_rng_nodes(
901+
variables: Iterable[Variable],
902+
) -> List[Union[RandomStateSharedVariable, RandomGeneratorSharedVariable]]:
898903
"""Return RNG variables in a graph"""
899904
return [
900905
node
901906
for node in graph_inputs(variables)
902-
if isinstance(
903-
node,
904-
(
905-
at.random.var.RandomStateSharedVariable,
906-
at.random.var.RandomGeneratorSharedVariable,
907-
),
908-
)
907+
if isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
909908
]
910909

911910

@@ -921,6 +920,7 @@ def reseed_rngs(
921920
np.random.PCG64(sub_seed) for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs))
922921
]
923922
for rng, bit_generator in zip(rngs, bit_generators):
923+
new_rng: Union[np.random.RandomState, np.random.Generator]
924924
if isinstance(rng, at.random.var.RandomStateSharedVariable):
925925
new_rng = np.random.RandomState(bit_generator)
926926
else:
@@ -980,6 +980,9 @@ def compile_pymc(
980980
and isinstance(var.owner.op, (RandomVariable, MeasurableVariable))
981981
and var not in inputs
982982
):
983+
# All nodes in `vars_between(inputs, outputs)` have owners.
984+
# But mypy doesn't know, so we just assert it:
985+
assert random_var.owner.op is not None
983986
if isinstance(random_var.owner.op, RandomVariable):
984987
rng = random_var.owner.inputs[0]
985988
if not hasattr(rng, "default_update"):

pymc/distributions/distribution.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ def __new__(
198198
total_size=None,
199199
transform=UNSET,
200200
**kwargs,
201-
) -> RandomVariable:
202-
"""Adds a RandomVariable corresponding to a PyMC distribution to the current model.
201+
) -> TensorVariable:
202+
"""Adds a tensor variable corresponding to a PyMC distribution to the current model.
203203
204204
Note that all remaining kwargs must be compatible with ``.dist()``
205205
@@ -231,8 +231,8 @@ def __new__(
231231
232232
Returns
233233
-------
234-
rv : RandomVariable
235-
The created RV, registered in the Model.
234+
rv : TensorVariable
235+
The created random variable tensor, registered in the Model.
236236
"""
237237

238238
try:
@@ -296,8 +296,8 @@ def dist(
296296
*,
297297
shape: Optional[Shape] = None,
298298
**kwargs,
299-
) -> RandomVariable:
300-
"""Creates a RandomVariable corresponding to the `cls` distribution.
299+
) -> TensorVariable:
300+
"""Creates a tensor variable corresponding to the `cls` distribution.
301301
302302
Parameters
303303
----------
@@ -314,8 +314,8 @@ def dist(
314314
315315
Returns
316316
-------
317-
rv : RandomVariable
318-
The created RV.
317+
rv : TensorVariable
318+
The created random variable tensor.
319319
"""
320320
if "testval" in kwargs:
321321
kwargs.pop("testval")
@@ -653,8 +653,8 @@ def __new__(
653653
name : str
654654
dist_params : Tuple
655655
A sequence of the distribution's parameter. These will be converted into
656-
Aesara tensors internally. These parameters could be other ``RandomVariable``
657-
instances.
656+
Aesara tensors internally. These parameters could be other ``TensorVariable``
657+
instances created from , optionally created via ``RandomVariable`` ``Op``s.
658658
logp : Optional[Callable]
659659
A callable that calculates the log density of some given observed ``value``
660660
conditioned on certain distribution parameter values. It must have the

pymc/distributions/logprob.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@
3939
from pymc.aesaraf import floatX
4040

4141

42-
def _get_scaling(total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int):
42+
def _get_scaling(
43+
total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int
44+
) -> TensorVariable:
4345
"""
4446
Gets scaling constant for logp.
4547
@@ -288,14 +290,14 @@ def logp(rv: TensorVariable, value) -> TensorVariable:
288290
raise NotImplementedError("PyMC could not infer logp of input variable.") from exc
289291

290292

291-
def logcdf(rv, value):
293+
def logcdf(rv: TensorVariable, value) -> TensorVariable:
292294
"""Return the log-cdf graph of a Random Variable"""
293295

294296
value = at.as_tensor_variable(value, dtype=rv.dtype)
295297
return logcdf_aeppl(rv, value)
296298

297299

298-
def ignore_logprob(rv):
300+
def ignore_logprob(rv: TensorVariable) -> TensorVariable:
299301
"""Return a duplicated variable that is ignored when creating Aeppl logprob graphs
300302
301303
This is used in SymbolicDistributions that use other RVs as inputs but account

0 commit comments

Comments
 (0)