Skip to content

Commit 06fa1c3

Browse files
Add return type hint
1 parent ccd8ec7 commit 06fa1c3

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

pytensor/tensor/rewriting/linalg.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import logging
2+
from collections.abc import Callable
23
from typing import cast
34

4-
from pytensor.graph import FunctionGraph, Apply
5+
from pytensor import Variable
6+
from pytensor.graph import Apply, FunctionGraph
57
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
68
from pytensor.tensor.basic import TensorVariable, diagonal
79
from pytensor.tensor.blas import Dot22
@@ -320,7 +322,9 @@ def local_log_prod_sqr(fgraph, node):
320322

321323
@register_specialize
322324
@node_rewriter([Blockwise])
323-
def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
325+
def local_lift_through_linalg(
326+
fgraph: FunctionGraph, node: Apply
327+
) -> list[Variable] | None:
324328
"""
325329
Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
326330
that join matrices (KroneckerProduct, BlockDiagonal).
@@ -340,7 +344,7 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
340344
341345
Returns
342346
-------
343-
res: list of Variable, optional
347+
list of Variable, optional
344348
List of optimized variables, or None if no optimization was performed
345349
"""
346350

@@ -357,15 +361,15 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
357361
input_matrices = y.owner.inputs
358362

359363
if isinstance(outer_op.core_op, MatrixInverse):
360-
outer_f = inv
364+
outer_f = cast(Callable, inv)
361365
elif isinstance(outer_op.core_op, Cholesky):
362-
outer_f = cholesky
366+
outer_f = cast(Callable, cholesky)
363367
elif isinstance(outer_op.core_op, MatrixPinv):
364-
outer_f = pinv
368+
outer_f = cast(Callable, pinv)
365369
else:
366370
raise NotImplementedError
367371

368-
inner_matrices = [outer_f(m) for m in input_matrices]
372+
inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices]
369373

370374
if isinstance(y.owner.op, KroneckerProduct):
371375
return [kron(*inner_matrices)]

0 commit comments

Comments
 (0)