18
18
19
19
20
20
class MatrixPinv (Op ):
21
- """Computes the pseudo-inverse of a matrix :math:`A`.
22
-
23
- The pseudo-inverse of a matrix :math:`A`, denoted :math:`A^+`, is
24
- defined as: "the matrix that 'solves' [the least-squares problem]
25
- :math:`Ax = b`," i.e., if :math:`\\ bar{x}` is said solution, then
26
- :math:`A^+` is that matrix such that :math:`\\ bar{x} = A^+b`.
27
-
28
- Note that :math:`Ax=AA^+b`, so :math:`AA^+` is close to the identity matrix.
29
- This method is not faster than `matrix_inverse`. Its strength comes from
30
- that it works for non-square matrices.
31
- If you have a square matrix though, `matrix_inverse` can be both more
32
- exact and faster to compute. Also this op does not get optimized into a
33
- solve op.
21
+ __props__ = ("hermitian" ,)
34
22
35
- """
36
-
37
- __props__ = ()
38
-
39
- def __init__ (self ):
40
- pass
23
+ def __init__ (self , hermitian ):
24
+ self .hermitian = hermitian
41
25
42
26
def make_node (self , x ):
43
27
x = as_tensor_variable (x )
@@ -47,7 +31,7 @@ def make_node(self, x):
47
31
def perform (self , node , inputs , outputs ):
48
32
(x ,) = inputs
49
33
(z ,) = outputs
50
- z [0 ] = np .linalg .pinv (x ).astype (x .dtype )
34
+ z [0 ] = np .linalg .pinv (x , hermitian = self . hermitian ).astype (x .dtype )
51
35
52
36
def L_op (self , inputs , outputs , g_outputs ):
53
37
r"""The gradient function should return
@@ -75,8 +59,46 @@ def L_op(self, inputs, outputs, g_outputs):
75
59
).T
76
60
return [grad ]
77
61
62
+ def infer_shape (self , fgraph , node , shapes ):
63
+ return [list (reversed (shapes [0 ]))]
64
+
65
+
66
+ def pinv (x , hermitian = False ):
67
+ """Computes the pseudo-inverse of a matrix :math:`A`.
68
+
69
+ The pseudo-inverse of a matrix :math:`A`, denoted :math:`A^+`, is
70
+ defined as: "the matrix that 'solves' [the least-squares problem]
71
+ :math:`Ax = b`," i.e., if :math:`\\ bar{x}` is said solution, then
72
+ :math:`A^+` is that matrix such that :math:`\\ bar{x} = A^+b`.
73
+
74
+ Note that :math:`Ax=AA^+b`, so :math:`AA^+` is close to the identity matrix.
75
+ This method is not faster than `matrix_inverse`. Its strength comes from
76
+ that it works for non-square matrices.
77
+ If you have a square matrix though, `matrix_inverse` can be both more
78
+ exact and faster to compute. Also this op does not get optimized into a
79
+ solve op.
80
+
81
+ """
82
+ return MatrixPinv (hermitian = hermitian )(x )
83
+
84
+
85
+ class Inv (Op ):
86
+ """Computes the inverse of one or more matrices."""
87
+
88
+ def make_node (self , x ):
89
+ x = as_tensor_variable (x )
90
+ return Apply (self , [x ], [x .type ()])
91
+
92
+ def perform (self , node , inputs , outputs ):
93
+ (x ,) = inputs
94
+ (z ,) = outputs
95
+ z [0 ] = np .linalg .inv (x ).astype (x .dtype )
96
+
97
+ def infer_shape (self , fgraph , node , shapes ):
98
+ return shapes
99
+
78
100
79
- pinv = MatrixPinv ()
101
+ inv = Inv ()
80
102
81
103
82
104
class MatrixInverse (Op ):
0 commit comments