Skip to content

Commit ccd8ec7

Browse files
Improve docstring
1 parent 0858f47 commit ccd8ec7

File tree

3 files changed

+17
-25
lines changed

3 files changed

+17
-25
lines changed

pytensor/tensor/nlinalg.py

-2
Original file line numberDiff line numberDiff line change
@@ -1017,8 +1017,6 @@ class KroneckerProduct(OpFromGraph):
10171017
Wrapper Op for Kronecker graphs
10181018
"""
10191019

1020-
...
1021-
10221020

10231021
def kron(a, b):
10241022
"""Kronecker product.

pytensor/tensor/rewriting/linalg.py

+17-21
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from typing import cast
33

4+
from pytensor.graph import FunctionGraph, Apply
45
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
56
from pytensor.tensor.basic import TensorVariable, diagonal
67
from pytensor.tensor.blas import Dot22
@@ -317,37 +318,32 @@ def local_log_prod_sqr(fgraph, node):
317318
# returns the sign of the prod multiplication.
318319

319320

320-
def local_inv_kron_to_kron_inv(fgraph, node):
321-
# check if we have a kron
322-
# check if parent node is an inv
323-
# if yes, replace with kron(inv, inv)
324-
325-
pass
326-
327-
328-
def local_chol_kron_to_kron_chol(fgraph, node):
329-
# check if we have a kron
330-
# check if parent node is a cholesky
331-
# if yes, replace with kron(cholesky, cholesky)
332-
333-
pass
334-
335-
336321
@register_specialize
337322
@node_rewriter([Blockwise])
338-
def local_lift_through_linalg(fgraph, node):
323+
def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
339324
"""
340-
Rewrite a graph like Inv(BlockDiag([A, B, C])) to BlockDiag([Inv(A), Inv(B), Inv(C)])
325+
Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
326+
that join matrices (KroneckerProduct, BlockDiagonal).
327+
328+
This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix
329+
operations on component matrices instead of one large one. For example, when taking the inverse of Kronecker
330+
product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This
331+
reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component
332+
matrices.
341333
342334
Parameters
343335
----------
344-
fgraph
345-
node
336+
fgraph: FunctionGraph
337+
Function graph being optimized
338+
node: Apply
339+
Node of the function graph to be optimized
346340
347341
Returns
348342
-------
349-
343+
res: list of Variable, optional
344+
List of optimized variables, or None if no optimization was performed
350345
"""
346+
351347
# TODO: Simplify this if we end up Blockwising KroneckerProduct
352348
if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
353349
y = node.inputs[0]

tests/tensor/rewriting/test_linalg.py

-2
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,4 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
386386
]
387387
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
388388

389-
f2(*test_vals)
390-
391389
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)

0 commit comments

Comments
 (0)