Skip to content

Commit 2b95ac2

Browse files
committed
MAINT: address review comments
1 parent 1f1973d commit 2b95ac2

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

torch_np/tests/numpy_tests/core/test_scalarmath.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,11 @@ class TestModulus:
261261

262262
def test_modulus_basic(self):
263263
dt = np.typecodes['AllInteger'] + np.typecodes['Float']
264-
for op in [floordiv_and_mod,]: # TODO: divmod is not implemented
264+
for op in [floordiv_and_mod, divmod]:
265+
266+
if op == divmod:
267+
pytest.xfail(reason="__divmod__ not implemented")
268+
265269
for dt1, dt2 in itertools.product(dt, dt):
266270
for sg1, sg2 in itertools.product(_signs(dt1), _signs(dt2)):
267271
fmt = 'op: %s, dt1: %s, dt2: %s, sg1: %s, sg2: %s'
@@ -306,7 +310,11 @@ def test_float_modulus_exact(self):
306310
def test_float_modulus_roundoff(self):
307311
# gh-6127
308312
dt = np.typecodes['Float']
309-
for op in [floordiv_and_mod]: # TODO divmod is not implemented
313+
for op in [floordiv_and_mod, divmod]:
314+
315+
if op == divmod:
316+
pytest.xfail(reason="__divmod__ not implemented")
317+
310318
for dt1, dt2 in itertools.product(dt, dt):
311319
for sg1, sg2 in itertools.product((+1, -1), (+1, -1)):
312320
fmt = 'op: %s, dt1: %s, dt2: %s, sg1: %s, sg2: %s'
@@ -360,7 +368,7 @@ def test_float_modulus_corner_cases(self):
360368

361369
class TestComplexDivision:
362370

363-
@pytest.mark.xfail(reason='With pytorch, 1/(0+0j) is nan + nan*j, not inf + nan*j')
371+
@pytest.mark.skip(reason='With pytorch, 1/(0+0j) is nan + nan*j, not inf + nan*j')
364372
def test_zero_division(self):
365373
with np.errstate(all="ignore"):
366374
for t in [np.complex64, np.complex128]:
@@ -441,7 +449,6 @@ def test_int_from_long(self):
441449
a = np.array(l, dtype=T)
442450
assert_equal([int(_m) for _m in a], li)
443451

444-
445452
@pytest.mark.xfail(reason="pytorch does not emit this warning.")
446453
def test_iinfo_long_values_1(self):
447454
for code in 'bBh':
@@ -681,7 +688,7 @@ def test_numpy_abs(self, dtype):
681688
self._test_abs_func(np.abs, dtype)
682689

683690

684-
@pytest.mark.skip(reason='TODO: implement bit shifts')
691+
@pytest.mark.xfail(reason='TODO: implement bit shifts')
685692
class TestBitShifts:
686693

687694
@pytest.mark.parametrize('type_code', np.typecodes['AllInteger'])

0 commit comments

Comments
 (0)