Skip to content

Commit aac3305

Browse files
Import kron and KroneckerProduct from nlinalg
1 parent d39fb5e commit aac3305

File tree

3 files changed

+28
-7
lines changed

3 files changed

+28
-7
lines changed

pytensor/tensor/nlinalg.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from numpy.core.numeric import normalize_axis_tuple # type: ignore
88

99
from pytensor import scalar as ps
10+
from pytensor.compile.builders import OpFromGraph
1011
from pytensor.gradient import DisconnectedType
1112
from pytensor.graph.basic import Apply
1213
from pytensor.graph.op import Op
@@ -1011,6 +1012,14 @@ def tensorsolve(a, b, axes=None):
10111012
return TensorSolve(axes)(a, b)
10121013

10131014

1015+
class KroneckerProduct(OpFromGraph):
1016+
"""
1017+
Wrapper Op for Kronecker graphs
1018+
"""
1019+
1020+
...
1021+
1022+
10141023
def kron(a, b):
10151024
"""Kronecker product.
10161025
@@ -1042,7 +1051,8 @@ def kron(a, b):
10421051
out_shape = tuple(a.shape * b.shape)
10431052
output_out_of_shape = a_reshaped * b_reshaped
10441053
output_reshaped = output_out_of_shape.reshape(out_shape)
1045-
return output_reshaped
1054+
1055+
return KroneckerProduct(inputs=[a, b], outputs=[output_reshaped])(a, b)
10461056

10471057

10481058
__all__ = [

pytensor/tensor/rewriting/linalg.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@
77
from pytensor.tensor.blockwise import Blockwise
88
from pytensor.tensor.elemwise import DimShuffle
99
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
10-
from pytensor.tensor.nlinalg import MatrixInverse, MatrixPinv, det, inv, pinv
10+
from pytensor.tensor.nlinalg import (
11+
KroneckerProduct,
12+
MatrixInverse,
13+
MatrixPinv,
14+
det,
15+
inv,
16+
kron,
17+
pinv,
18+
)
1119
from pytensor.tensor.rewriting.basic import (
1220
register_canonicalize,
1321
register_specialize,
@@ -16,12 +24,10 @@
1624
from pytensor.tensor.slinalg import (
1725
BlockDiagonal,
1826
Cholesky,
19-
KroneckerProduct,
2027
Solve,
2128
SolveBase,
2229
block_diag,
2330
cholesky,
24-
kron,
2531
solve,
2632
solve_triangular,
2733
)
@@ -348,7 +354,7 @@ def local_lift_through_linalg(fgraph, node):
348354
349355
"""
350356
# TODO: Simplify this if we end up Blockwising KroneckerProduct
351-
if isinstance(node.op.core_op, (MatrixInverse, Cholesky, MatrixPinv)):
357+
if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
352358
y = node.inputs[0]
353359
outer_op = node.op
354360

tests/tensor/rewriting/test_linalg.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,17 @@
1414
from pytensor.tensor.blockwise import Blockwise
1515
from pytensor.tensor.elemwise import DimShuffle
1616
from pytensor.tensor.math import _allclose, dot, matmul
17-
from pytensor.tensor.nlinalg import Det, MatrixInverse, MatrixPinv, matrix_inverse
17+
from pytensor.tensor.nlinalg import (
18+
Det,
19+
KroneckerProduct,
20+
MatrixInverse,
21+
MatrixPinv,
22+
matrix_inverse,
23+
)
1824
from pytensor.tensor.rewriting.linalg import inv_as_solve
1925
from pytensor.tensor.slinalg import (
2026
BlockDiagonal,
2127
Cholesky,
22-
KroneckerProduct,
2328
Solve,
2429
SolveBase,
2530
SolveTriangular,

0 commit comments

Comments
 (0)