Skip to content

Commit d32640e

Browse files
Add return type hint
1 parent b6c9692 commit d32640e

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

pytensor/tensor/rewriting/linalg.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import logging
2+
from collections.abc import Callable
23
from typing import cast
34

4-
from pytensor.graph import FunctionGraph, Apply
5+
from pytensor import Variable
6+
from pytensor.graph import Apply, FunctionGraph
57
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
68
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
79
from pytensor.tensor.blas import Dot22
@@ -325,7 +327,9 @@ def local_log_prod_sqr(fgraph, node):
325327

326328
@register_specialize
327329
@node_rewriter([Blockwise])
328-
def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
330+
def local_lift_through_linalg(
331+
fgraph: FunctionGraph, node: Apply
332+
) -> list[Variable] | None:
329333
"""
330334
Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
331335
that join matrices (KroneckerProduct, BlockDiagonal).
@@ -345,7 +349,7 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
345349
346350
Returns
347351
-------
348-
res: list of Variable, optional
352+
list of Variable, optional
349353
List of optimized variables, or None if no optimization was performed
350354
"""
351355

@@ -362,15 +366,15 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
362366
input_matrices = y.owner.inputs
363367

364368
if isinstance(outer_op.core_op, MatrixInverse):
365-
outer_f = inv
369+
outer_f = cast(Callable, inv)
366370
elif isinstance(outer_op.core_op, Cholesky):
367-
outer_f = cholesky
371+
outer_f = cast(Callable, cholesky)
368372
elif isinstance(outer_op.core_op, MatrixPinv):
369-
outer_f = pinv
373+
outer_f = cast(Callable, pinv)
370374
else:
371375
raise NotImplementedError
372376

373-
inner_matrices = [outer_f(m) for m in input_matrices]
377+
inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices]
374378

375379
if isinstance(y.owner.op, KroneckerProduct):
376380
return [kron(*inner_matrices)]

0 commit comments

Comments
 (0)