@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy as np
 import numpy.linalg
 import pytest
@@ -34,6 +36,7 @@
     lscalar,
     matrix,
     scalar,
+    tensor,
    tensor3,
    tensor4,
    vector,
@@ -150,29 +153,52 @@ def test_qr_modes():
 
 class TestSvd(utt.InferShapeTester):
     op_class = SVD
-    dtype = "float32"
 
     def setup_method(self):
         super().setup_method()
         self.rng = np.random.default_rng(utt.fetch_seed())
-        self.A = matrix(dtype=self.dtype)
+        self.A = matrix(dtype=config.floatX)
         self.op = svd
 
-    def test_svd(self):
-        A = matrix("A", dtype=self.dtype)
-        U, S, VT = svd(A)
-        fn = function([A], [U, S, VT])
-        a = self.rng.random((4, 4)).astype(self.dtype)
-        n_u, n_s, n_vt = np.linalg.svd(a)
-        t_u, t_s, t_vt = fn(a)
+    @pytest.mark.parametrize(
+        "core_shape", [(3, 3), (4, 3), (3, 4)], ids=["square", "tall", "wide"]
+    )
+    @pytest.mark.parametrize(
+        "full_matrix", [True, False], ids=["full=True", "full=False"]
+    )
+    @pytest.mark.parametrize(
+        "compute_uv", [True, False], ids=["compute_uv=True", "compute_uv=False"]
+    )
+    @pytest.mark.parametrize(
+        "batched", [True, False], ids=["batched=True", "batched=False"]
+    )
+    @pytest.mark.parametrize(
+        "test_imag", [True, False], ids=["test_imag=True", "test_imag=False"]
+    )
+    def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag):
+        dtype = config.floatX
+        if test_imag:
+            dtype = "complex128" if dtype.endswith("64") else "complex64"
+        shape = core_shape if not batched else (10, *core_shape)
+        A = tensor("A", shape=shape, dtype=dtype)
+        a = self.rng.random(shape).astype(dtype)
+
+        outputs = svd(A, compute_uv=compute_uv, full_matrices=full_matrix)
+        outputs = outputs if isinstance(outputs, list) else [outputs]
+        fn = function(inputs=[A], outputs=outputs)
+
+        np_fn = np.vectorize(
+            partial(np.linalg.svd, compute_uv=compute_uv, full_matrices=full_matrix),
+            signature=outputs[0].owner.op.core_op.gufunc_signature,
+        )
+
+        np_outputs = np_fn(a)
+        pt_outputs = fn(a)
 
-        assert _allclose(n_u, t_u)
-        assert _allclose(n_s, t_s)
-        assert _allclose(n_vt, t_vt)
+        np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs]
 
-        fn = function([A], svd(A, compute_uv=False))
-        t_s = fn(a)
-        assert _allclose(n_s, t_s)
+        for np_val, pt_val in zip(np_outputs, pt_outputs):
+            assert _allclose(np_val, pt_val)
 
     def test_svd_infer_shape(self):
         self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
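For reference, here is a minimal, self-contained sketch of the reference construction used in the new `test_svd`: `np.vectorize` with a gufunc signature applies `np.linalg.svd` once per leading batch entry and stacks the results. The signature string below is written out by hand as an assumption; the test itself reads it from `outputs[0].owner.op.core_op.gufunc_signature`.

```python
from functools import partial

import numpy as np

# Batch of 10 tall (4, 3) matrices.
a = np.random.default_rng(0).random((10, 4, 3))

# Hand-written signature for full_matrices=False: one (U, s, Vh) triple per matrix.
np_svd = np.vectorize(
    partial(np.linalg.svd, compute_uv=True, full_matrices=False),
    signature="(m,n)->(m,k),(k),(k,n)",
)

u, s, vh = np_svd(a)
# Output core dims not present in the input (here k = min(m, n)) are taken
# from the first call, so the batched shapes come out as expected.
assert u.shape == (10, 4, 3) and s.shape == (10, 3) and vh.shape == (10, 3, 3)
```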
@@ -183,7 +209,7 @@ def test_svd_infer_shape(self):
 
     def validate_shape(self, shape, compute_uv=True, full_matrices=True):
         A = self.A
-        A_v = self.rng.random(shape).astype(self.dtype)
+        A_v = self.rng.random(shape).astype(config.floatX)
         outputs = self.op(A, full_matrices=full_matrices, compute_uv=compute_uv)
         if not compute_uv:
             outputs = [outputs]
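As context for the `outputs = [outputs]` normalization above: with `compute_uv=False`, `np.linalg.svd` returns only the singular values rather than a `(U, s, Vh)` triple, and the `svd` helper under test is treated the same way in this diff, so the single output is wrapped in a list before iterating. A small NumPy illustration:

```python
import numpy as np

a = np.random.default_rng(0).random((4, 3))

u, s, vh = np.linalg.svd(a)                  # three outputs
s_only = np.linalg.svd(a, compute_uv=False)  # a single array of singular values
assert np.allclose(s, s_only)
```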
@@ -451,8 +477,8 @@ def test_non_tensorial_input(self):
             norm(3, None)
 
     def test_tensor_input(self):
-        with pytest.raises(NotImplementedError):
-            norm(np.random.random((3, 4, 5)), None)
+        res = norm(np.random.random((3, 4, 5)), None)
+        assert res.shape.eval() == (3,)
 
     def test_numpy_compare(self):
         rng = np.random.default_rng(utt.fetch_seed())
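The updated `test_tensor_input` expects `norm` on a `(3, 4, 5)` input to return one value per leading batch entry instead of raising `NotImplementedError`. Assuming the new behavior mirrors a matrix norm applied to each trailing `(4, 5)` matrix (an assumption; the diff only asserts the output shape), the NumPy equivalent would be:

```python
import numpy as np

a = np.random.default_rng(0).random((3, 4, 5))

# Frobenius norm of each trailing (4, 5) matrix, one result per batch entry.
batched_fro = np.linalg.norm(a, ord=None, axis=(-2, -1))
assert batched_fro.shape == (3,)
```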