@@ -216,21 +216,29 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
216
216
self ._compile_and_check ([A ], outputs , [A_v ], self .op_class , warn = False )
217
217
218
218
@pytest .mark .parametrize (
219
- "compute_uv, full_matrices" ,
220
- [(True , False ), (False , False ), (True , True )],
221
- ids = [
222
- "compute_uv=True, full_matrices=False" ,
223
- "compute_uv=False, full_matrices=False" ,
224
- "compute_uv=True, full_matrices=True" ,
225
- ],
219
+ "compute_uv, full_matrices, gradient_test_case" ,
220
+ [(False , False , 0 )]
221
+ + [(True , False , i ) for i in range (7 )]
222
+ + [(True , True , i ) for i in range (7 )],
223
+ ids = (
224
+ ["compute_uv=False, full_matrices=False" ]
225
+ + [
226
+ f"compute_uv=True, full_matrices=False, gradient={ grad } "
227
+ for grad in ["U" , "s" , "V" , "U+s" , "s+V" , "U+V" , "U+s+V" ]
228
+ ]
229
+ + [
230
+ f"compute_uv=True, full_matrices=True, gradient={ grad } "
231
+ for grad in ["U" , "s" , "V" , "U+s" , "s+V" , "U+V" , "U+s+V" ]
232
+ ]
233
+ ),
226
234
)
227
235
@pytest .mark .parametrize (
228
236
"shape" , [(3 , 3 ), (4 , 3 ), (3 , 4 )], ids = ["(3,3)" , "(4,3)" , "(3,4)" ]
229
237
)
230
238
@pytest .mark .parametrize (
231
239
"batched" , [True , False ], ids = ["batched=True" , "batched=False" ]
232
240
)
233
- def test_grad (self , compute_uv , full_matrices , shape , batched ):
241
+ def test_grad (self , compute_uv , full_matrices , gradient_test_case , shape , batched ):
234
242
rng = np .random .default_rng (utt .fetch_seed ())
235
243
if batched :
236
244
shape = (4 , * shape )
@@ -248,15 +256,29 @@ def test_grad(self, compute_uv, full_matrices, shape, batched):
248
256
249
257
elif compute_uv :
250
258
251
- def svd_fn (A ):
259
+ def svd_fn (A , case = 0 ):
252
260
U , s , V = svd (A , compute_uv = compute_uv , full_matrices = full_matrices )
253
- return U .sum () + s .sum () + V .sum ()
254
-
255
- utt .verify_grad (
256
- svd_fn ,
257
- [A_v ],
258
- rng = rng ,
259
- )
261
+ if case == 0 :
262
+ return U .sum ()
263
+ elif case == 1 :
264
+ return s .sum ()
265
+ elif case == 2 :
266
+ return V .sum ()
267
+ elif case == 3 :
268
+ return U .sum () + s .sum ()
269
+ elif case == 4 :
270
+ return s .sum () + V .sum ()
271
+ elif case == 5 :
272
+ return U .sum () + V .sum ()
273
+ elif case == 6 :
274
+ return U .sum () + s .sum () + V .sum ()
275
+
276
+ for case in range (7 ):
277
+ utt .verify_grad (
278
+ partial (svd_fn , case = gradient_test_case ),
279
+ [A_v ],
280
+ rng = rng ,
281
+ )
260
282
261
283
else :
262
284
utt .verify_grad (
0 commit comments