Skip to content

Commit a6e461b

Browse files
Add missing hermitian option to MatrixPinv
1 parent 40313aa commit a6e461b

File tree

1 file changed

+43
-21
lines changed

1 file changed

+43
-21
lines changed

aesara/tensor/nlinalg.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,10 @@
1818

1919

2020
class MatrixPinv(Op):
21-
"""Computes the pseudo-inverse of a matrix :math:`A`.
22-
23-
The pseudo-inverse of a matrix :math:`A`, denoted :math:`A^+`, is
24-
defined as: "the matrix that 'solves' [the least-squares problem]
25-
:math:`Ax = b`," i.e., if :math:`\\bar{x}` is said solution, then
26-
:math:`A^+` is that matrix such that :math:`\\bar{x} = A^+b`.
27-
28-
Note that :math:`Ax=AA^+b`, so :math:`AA^+` is close to the identity matrix.
29-
This method is not faster than `matrix_inverse`. Its strength comes from
30-
that it works for non-square matrices.
31-
If you have a square matrix though, `matrix_inverse` can be both more
32-
exact and faster to compute. Also this op does not get optimized into a
33-
solve op.
21+
__props__ = ("hermitian",)
3422

35-
"""
36-
37-
__props__ = ()
38-
39-
def __init__(self):
40-
pass
23+
def __init__(self, hermitian):
24+
self.hermitian = hermitian
4125

4226
def make_node(self, x):
4327
x = as_tensor_variable(x)
@@ -47,7 +31,7 @@ def make_node(self, x):
4731
def perform(self, node, inputs, outputs):
4832
(x,) = inputs
4933
(z,) = outputs
50-
z[0] = np.linalg.pinv(x).astype(x.dtype)
34+
z[0] = np.linalg.pinv(x, hermitian=self.hermitian).astype(x.dtype)
5135

5236
def L_op(self, inputs, outputs, g_outputs):
5337
r"""The gradient function should return
@@ -75,8 +59,46 @@ def L_op(self, inputs, outputs, g_outputs):
7559
).T
7660
return [grad]
7761

62+
def infer_shape(self, fgraph, node, shapes):
63+
return [list(reversed(shapes[0]))]
64+
65+
66+
def pinv(x, hermitian=False):
67+
"""Computes the pseudo-inverse of a matrix :math:`A`.
68+
69+
The pseudo-inverse of a matrix :math:`A`, denoted :math:`A^+`, is
70+
defined as: "the matrix that 'solves' [the least-squares problem]
71+
:math:`Ax = b`," i.e., if :math:`\\bar{x}` is said solution, then
72+
:math:`A^+` is that matrix such that :math:`\\bar{x} = A^+b`.
73+
74+
Note that :math:`Ax=AA^+b`, so :math:`AA^+` is close to the identity matrix.
75+
This method is not faster than `matrix_inverse`. Its strength comes from
76+
that it works for non-square matrices.
77+
If you have a square matrix though, `matrix_inverse` can be both more
78+
exact and faster to compute. Also this op does not get optimized into a
79+
solve op.
80+
81+
"""
82+
return MatrixPinv(hermitian=hermitian)(x)
83+
84+
85+
class Inv(Op):
86+
"""Computes the inverse of one or more matrices."""
87+
88+
def make_node(self, x):
89+
x = as_tensor_variable(x)
90+
return Apply(self, [x], [x.type()])
91+
92+
def perform(self, node, inputs, outputs):
93+
(x,) = inputs
94+
(z,) = outputs
95+
z[0] = np.linalg.inv(x).astype(x.dtype)
96+
97+
def infer_shape(self, fgraph, node, shapes):
98+
return shapes
99+
78100

79-
pinv = MatrixPinv()
101+
inv = Inv()
80102

81103

82104
class MatrixInverse(Op):

0 commit comments

Comments
 (0)