1
1
import logging
2
+ from collections .abc import Callable
2
3
from typing import cast
3
4
4
- from pytensor .graph import FunctionGraph , Apply
5
+ from pytensor import Variable
6
+ from pytensor .graph import Apply , FunctionGraph
5
7
from pytensor .graph .rewriting .basic import copy_stack_trace , node_rewriter
6
8
from pytensor .tensor .basic import TensorVariable , diagonal , swapaxes
7
9
from pytensor .tensor .blas import Dot22
@@ -325,7 +327,9 @@ def local_log_prod_sqr(fgraph, node):
325
327
326
328
@register_specialize
327
329
@node_rewriter ([Blockwise ])
328
- def local_lift_through_linalg (fgraph : FunctionGraph , node : Apply ):
330
+ def local_lift_through_linalg (
331
+ fgraph : FunctionGraph , node : Apply
332
+ ) -> list [Variable ] | None :
329
333
"""
330
334
Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
331
335
that join matrices (KroneckerProduct, BlockDiagonal).
@@ -345,7 +349,7 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
345
349
346
350
Returns
347
351
-------
348
- res: list of Variable, optional
352
+ list of Variable, optional
349
353
List of optimized variables, or None if no optimization was performed
350
354
"""
351
355
@@ -362,15 +366,15 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
362
366
input_matrices = y .owner .inputs
363
367
364
368
if isinstance (outer_op .core_op , MatrixInverse ):
365
- outer_f = inv
369
+ outer_f = cast ( Callable , inv )
366
370
elif isinstance (outer_op .core_op , Cholesky ):
367
- outer_f = cholesky
371
+ outer_f = cast ( Callable , cholesky )
368
372
elif isinstance (outer_op .core_op , MatrixPinv ):
369
- outer_f = pinv
373
+ outer_f = cast ( Callable , pinv )
370
374
else :
371
375
raise NotImplementedError
372
376
373
- inner_matrices = [outer_f (m ) for m in input_matrices ]
377
+ inner_matrices = [cast ( TensorVariable , outer_f (m ) ) for m in input_matrices ]
374
378
375
379
if isinstance (y .owner .op , KroneckerProduct ):
376
380
return [kron (* inner_matrices )]
0 commit comments