1
1
import logging
2
+ from collections .abc import Callable
2
3
from typing import cast
3
4
4
- from pytensor .graph import FunctionGraph , Apply
5
+ from pytensor import Variable
6
+ from pytensor .graph import Apply , FunctionGraph
5
7
from pytensor .graph .rewriting .basic import copy_stack_trace , node_rewriter
6
8
from pytensor .tensor .basic import TensorVariable , diagonal
7
9
from pytensor .tensor .blas import Dot22
@@ -320,7 +322,9 @@ def local_log_prod_sqr(fgraph, node):
320
322
321
323
@register_specialize
322
324
@node_rewriter ([Blockwise ])
323
- def local_lift_through_linalg (fgraph : FunctionGraph , node : Apply ):
325
+ def local_lift_through_linalg (
326
+ fgraph : FunctionGraph , node : Apply
327
+ ) -> list [Variable ] | None :
324
328
"""
325
329
Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
326
330
that join matrices (KroneckerProduct, BlockDiagonal).
@@ -340,7 +344,7 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
340
344
341
345
Returns
342
346
-------
343
- res: list of Variable, optional
347
+ list of Variable, optional
344
348
List of optimized variables, or None if no optimization was performed
345
349
"""
346
350
@@ -357,15 +361,15 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
357
361
input_matrices = y .owner .inputs
358
362
359
363
if isinstance (outer_op .core_op , MatrixInverse ):
360
- outer_f = inv
364
+ outer_f = cast ( Callable , inv )
361
365
elif isinstance (outer_op .core_op , Cholesky ):
362
- outer_f = cholesky
366
+ outer_f = cast ( Callable , cholesky )
363
367
elif isinstance (outer_op .core_op , MatrixPinv ):
364
- outer_f = pinv
368
+ outer_f = cast ( Callable , pinv )
365
369
else :
366
370
raise NotImplementedError
367
371
368
- inner_matrices = [outer_f (m ) for m in input_matrices ]
372
+ inner_matrices = [cast ( TensorVariable , outer_f (m ) ) for m in input_matrices ]
369
373
370
374
if isinstance (y .owner .op , KroneckerProduct ):
371
375
return [kron (* inner_matrices )]
0 commit comments