2
2
from pytest import raises as assert_raises
3
3
4
4
import torch_np as np
5
- from torch_np .testing import assert_equal , assert_array_equal , assert_allclose
5
+ from torch_np .testing import (assert_equal , assert_array_equal , assert_allclose ,
6
+ assert_almost_equal )
6
7
7
8
import torch_np ._util as _util
8
9
@@ -321,12 +322,10 @@ def test_sum(self):
321
322
def test_sum_stability (self ):
322
323
a = np .ones (500 , dtype = np .float32 )
323
324
zero = np .zeros (1 , dtype = 'float32' )[0 ]
324
- assert_allclose ((a / 10. ).sum () - a .size / 10. , zero , atol = 1.5e-4 ,
325
- check_dtype = False )
325
+ assert_allclose ((a / 10. ).sum () - a .size / 10. , zero , atol = 1.5e-4 )
326
326
327
327
a = np .ones (500 , dtype = np .float64 )
328
- assert_allclose ((a / 10. ).sum () - a .size / 10. , 0. , atol = 1.5e-13 ,
329
- check_dtype = False )
328
+ assert_allclose ((a / 10. ).sum () - a .size / 10. , 0. , atol = 1.5e-13 )
330
329
331
330
def test_sum_boolean (self ):
332
331
a = (np .arange (7 ) % 2 == 0 )
@@ -338,8 +337,8 @@ def test_sum_boolean(self):
338
337
assert res_float .dtype == 'float64'
339
338
340
339
341
- @pytest .mark .xfail (reason = "dtype(value) needs implementing " )
342
- def test_sum_dtypes (self ):
340
+ @pytest .mark .xfail (reason = "sum: does not warn on overflow " )
341
+ def test_sum_dtypes_warnings (self ):
343
342
for dt in (int , np .float16 , np .float32 , np .float64 ):
344
343
for v in (0 , 1 , 2 , 7 , 8 , 9 , 15 , 16 , 19 , 127 ,
345
344
128 , 1024 , 1235 ):
@@ -357,48 +356,54 @@ def test_sum_dtypes(self):
357
356
assert_almost_equal (np .sum (d ), tgt )
358
357
assert_equal (len (w ), 2 * overflow )
359
358
360
- assert_almost_equal (np .sum (d [:: - 1 ] ), tgt )
359
+ assert_almost_equal (np .sum (np . flip ( d ) ), tgt )
361
360
assert_equal (len (w ), 3 * overflow )
362
361
362
+ def test_sum_dtypes_2 (self ):
363
+ for dt in (int , np .float16 , np .float32 , np .float64 ):
363
364
d = np .ones (500 , dtype = dt )
364
365
assert_almost_equal (np .sum (d [::2 ]), 250. )
365
366
assert_almost_equal (np .sum (d [1 ::2 ]), 250. )
366
367
assert_almost_equal (np .sum (d [::3 ]), 167. )
367
368
assert_almost_equal (np .sum (d [1 ::3 ]), 167. )
368
- assert_almost_equal (np .sum (d [::- 2 ]), 250. )
369
- assert_almost_equal (np .sum (d [- 1 ::- 2 ]), 250. )
370
- assert_almost_equal (np .sum (d [::- 3 ]), 167. )
371
- assert_almost_equal (np .sum (d [- 1 ::- 3 ]), 167. )
369
+ assert_almost_equal (np .sum (np .flip (d )[::2 ]), 250. )
370
+
371
+ assert_almost_equal (np .sum (np .flip (d )[1 ::2 ]), 250. )
372
+
373
+ assert_almost_equal (np .sum (np .flip (d )[::3 ]), 167. )
374
+ assert_almost_equal (np .sum (np .flip (d )[1 ::3 ]), 167. )
375
+
372
376
# sum with first reduction entry != 0
373
377
d = np .ones ((1 ,), dtype = dt )
374
378
d += d
375
379
assert_almost_equal (d , 2. )
376
380
377
- @pytest .mark .xfail (reason = "dtype(value) needs implementing" )
378
- def test_sum_complex (self ):
379
- for dt in (np .complex64 , np .complex128 ):
380
- for v in (0 , 1 , 2 , 7 , 8 , 9 , 15 , 16 , 19 , 127 ,
381
- 128 , 1024 , 1235 ):
382
- tgt = dt (v * (v + 1 ) / 2 ) - dt ((v * (v + 1 ) / 2 ) * 1j )
383
- d = np .empty (v , dtype = dt )
384
- d .real = np .arange (1 , v + 1 )
385
- d .imag = - np .arange (1 , v + 1 )
386
- assert_allclose (np .sum (d ), tgt , atol = 1.5e-5 )
387
- assert_allcllose (np .sum (d [::- 1 ]), tgt , atol = 1.5e-7 )
388
-
389
- d = np .ones (500 , dtype = dt ) + 1j
390
- assert_allclose (np .sum (d [::2 ]), 250. + 250j , atol = 1.5e-7 )
391
- assert_allclose (np .sum (d [1 ::2 ]), 250. + 250j , atol = 1.5e-7 )
392
- assert_allclose (np .sum (d [::3 ]), 167. + 167j , atol = 1.5e-7 )
393
- assert_allclose (np .sum (d [1 ::3 ]), 167. + 167j , atol = 1.5e-7 )
394
- assert_allclose (np .sum (d [::- 2 ]), 250. + 250j , atol = 1.5e-7 )
395
- assert_allclose (np .sum (d [- 1 ::- 2 ]), 250. + 250j , atol = 1.5e-7 )
396
- assert_allclose (np .sum (d [::- 3 ]), 167. + 167j , atol = 1.5e-7 )
397
- assert_allclose (np .sum (d [- 1 ::- 3 ]), 167. + 167j , atol = 1.5e-7 )
398
- # sum with first reduction entry != 0
399
- d = np .ones ((1 ,), dtype = dt ) + 1j
400
- d += d
401
- assert_allclose (d , 2. + 2j , atol = 1.5e-7 )
381
+ @pytest .mark .parametrize ("dt" , [np .complex64 , np .complex128 ])
382
+ def test_sum_complex_1 (self , dt ):
383
+ for v in (0 , 1 , 2 , 7 , 8 , 9 , 15 , 16 , 19 , 127 ,
384
+ 128 , 1024 , 1235 ):
385
+ tgt = dt (v * (v + 1 ) / 2 ) - dt ((v * (v + 1 ) / 2 ) * 1j )
386
+ d = np .empty (v , dtype = dt )
387
+ d .real = np .arange (1 , v + 1 )
388
+ d .imag = - np .arange (1 , v + 1 )
389
+ assert_allclose (np .sum (d ), tgt , atol = 1.5e-5 )
390
+ assert_allclose (np .sum (np .flip (d )), tgt , atol = 1.5e-7 )
391
+
392
+ @pytest .mark .parametrize ("dt" , [np .complex64 , np .complex128 ])
393
+ def test_sum_complex_2 (self , dt ):
394
+ d = np .ones (500 , dtype = dt ) + 1j
395
+ assert_allclose (np .sum (d [::2 ]), 250. + 250j , atol = 1.5e-7 )
396
+ assert_allclose (np .sum (d [1 ::2 ]), 250. + 250j , atol = 1.5e-7 )
397
+ assert_allclose (np .sum (d [::3 ]), 167. + 167j , atol = 1.5e-7 )
398
+ assert_allclose (np .sum (d [1 ::3 ]), 167. + 167j , atol = 1.5e-7 )
399
+ assert_allclose (np .sum (np .flip (d )[::2 ]), 250. + 250j , atol = 1.5e-7 )
400
+ assert_allclose (np .sum (np .flip (d )[1 ::2 ]), 250. + 250j , atol = 1.5e-7 )
401
+ assert_allclose (np .sum (np .flip (d )[::3 ]), 167. + 167j , atol = 1.5e-7 )
402
+ assert_allclose (np .sum (np .flip (d )[1 ::3 ]), 167. + 167j , atol = 1.5e-7 )
403
+ # sum with first reduction entry != 0
404
+ d = np .ones ((1 ,), dtype = dt ) + 1j
405
+ d += d
406
+ assert_allclose (d , 2. + 2j , atol = 1.5e-7 )
402
407
403
408
@pytest .mark .xfail (reason = 'initial=... need implementing' )
404
409
def test_sum_initial (self ):
0 commit comments