|
1 | 1 | import logging
|
2 | 2 | from typing import cast
|
3 | 3 |
|
| 4 | +from pytensor.graph import FunctionGraph, Apply |
4 | 5 | from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
|
5 | 6 | from pytensor.tensor.basic import TensorVariable, diagonal
|
6 | 7 | from pytensor.tensor.blas import Dot22
|
@@ -317,37 +318,32 @@ def local_log_prod_sqr(fgraph, node):
|
317 | 318 | # returns the sign of the prod multiplication.
|
318 | 319 |
|
319 | 320 |
|
320 |
| -def local_inv_kron_to_kron_inv(fgraph, node): |
321 |
| - # check if we have a kron |
322 |
| - # check if parent node is an inv |
323 |
| - # if yes, replace with kron(inv, inv) |
324 |
| - |
325 |
| - pass |
326 |
| - |
327 |
| - |
328 |
| -def local_chol_kron_to_kron_chol(fgraph, node): |
329 |
| - # check if we have a kron |
330 |
| - # check if parent node is a cholesky |
331 |
| - # if yes, replace with kron(cholesky, cholesky) |
332 |
| - |
333 |
| - pass |
334 |
| - |
335 |
| - |
336 | 321 | @register_specialize
|
337 | 322 | @node_rewriter([Blockwise])
|
338 |
| -def local_lift_through_linalg(fgraph, node): |
| 323 | +def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply): |
339 | 324 | """
|
340 |
| - Rewrite a graph like Inv(BlockDiag([A, B, C])) to BlockDiag([Inv(A), Inv(B), Inv(C)]) |
| 325 | + Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops |
| 326 | + that join matrices (KroneckerProduct, BlockDiagonal). |
| 327 | +
|
| 328 | + This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix |
| 329 | + operations on component matrices instead of one large one. For example, when taking the inverse of Kronecker |
| 330 | + product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This |
| 331 | + reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component |
| 332 | + matrices. |
341 | 333 |
|
342 | 334 | Parameters
|
343 | 335 | ----------
|
344 |
| - fgraph |
345 |
| - node |
| 336 | + fgraph: FunctionGraph |
| 337 | + Function graph being optimized |
| 338 | + node: Apply |
| 339 | + Node of the function graph to be optimized |
346 | 340 |
|
347 | 341 | Returns
|
348 | 342 | -------
|
349 |
| -
|
| 343 | + res: list of Variable, optional |
| 344 | + List of optimized variables, or None if no optimization was performed |
350 | 345 | """
|
| 346 | + |
351 | 347 | # TODO: Simplify this if we end up Blockwising KroneckerProduct
|
352 | 348 | if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
|
353 | 349 | y = node.inputs[0]
|
|
0 commit comments