Skip to content

Commit d8aa07b

Browse files
committed
ENH: add a naive divmod, un-xfail relevant tests
1 parent 35e647b commit d8aa07b

File tree

4 files changed

+94
-44
lines changed

4 files changed

+94
-44
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,58 @@ def wrapped(
5757
decorated.__qualname__ = name # XXX: is this really correct?
5858
decorated.__name__ = name
5959
vars()[name] = decorated
60+
61+
62+
# a stub implementation of divmod, should be improved after
63+
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
64+
#
65+
# Implementation details: we just call two ufuncs which have been created
66+
# just above, for x1 // x2 and x1 % x2.
67+
# This means we are normalizing x1, x2 in each of the ufuncs --- note that there
68+
# is no @normalizer on divmod.
69+
70+
71+
def divmod(
72+
x1,
73+
x2,
74+
/,
75+
out=None,
76+
*,
77+
where=True,
78+
casting="same_kind",
79+
order="K",
80+
dtype=None,
81+
subok: SubokLike = False,
82+
signature=None,
83+
extobj=None,
84+
):
85+
out1, out2 = None, None
86+
if out is not None:
87+
out1, out2 = out
88+
89+
kwds = dict(
90+
where=where,
91+
casting=casting,
92+
order=order,
93+
dtype=dtype,
94+
subok=subok,
95+
signature=signature,
96+
extobj=extobj,
97+
)
98+
99+
# NB: use local names for
100+
quot = floor_divide(x1, x2, out=out1, **kwds)
101+
rem = remainder(x1, x2, out=out2, **kwds)
102+
103+
quot = _helpers.result_or_out(quot.get(), out1) # FIXME: .get() -> .tensor
104+
rem = _helpers.result_or_out(rem.get(), out2)
105+
106+
return quot, rem
107+
108+
109+
def modf(x, /, *args, **kwds):
110+
quot, rem = divmod(x, 1, *args, **kwds)
111+
return rem, quot
112+
113+
114+
__all__ = __all__ + ["divmod", "modf"]

torch_np/_ndarray.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def __rfloordiv__(self, other):
272272
def __ifloordiv__(self, other):
273273
return _binary_ufuncs.floor_divide(self, other, out=self)
274274

275+
__divmod__ = _binary_ufuncs.divmod
276+
275277
# power, self**exponent
276278
__pow__ = __rpow__ = _binary_ufuncs.float_power
277279

torch_np/_normalizations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def normalize_dtype(dtype, name=None):
5050
return torch_dtype
5151

5252

53-
def normalize_subok_like(arg, name):
53+
def normalize_subok_like(arg, name="subok"):
5454
if arg:
5555
raise ValueError(f"'{name}' parameter is not supported.")
5656

@@ -94,7 +94,7 @@ def normalize_this(arg, parm, return_on_failure=_sentinel):
9494
normalizer = normalizers.get(parm.annotation, None)
9595
if normalizer:
9696
try:
97-
return normalizer(arg)
97+
return normalizer(arg, parm.name)
9898
except Exception as exc:
9999
if return_on_failure is not _sentinel:
100100
return return_on_failure

torch_np/tests/numpy_tests/core/test_scalarmath.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,6 @@ class TestModulus:
262262
def test_modulus_basic(self):
263263
dt = np.typecodes['AllInteger'] + np.typecodes['Float']
264264
for op in [floordiv_and_mod, divmod]:
265-
266-
if op == divmod:
267-
pytest.xfail(reason="__divmod__ not implemented")
268-
269265
for dt1, dt2 in itertools.product(dt, dt):
270266
for sg1, sg2 in itertools.product(_signs(dt1), _signs(dt2)):
271267
fmt = 'op: %s, dt1: %s, dt2: %s, sg1: %s, sg2: %s'
@@ -279,7 +275,7 @@ def test_modulus_basic(self):
279275
else:
280276
assert_(b > rem >= 0, msg)
281277

282-
@pytest.mark.xfail(reason='divmod not implemented')
278+
@pytest.mark.slow
283279
def test_float_modulus_exact(self):
284280
# test that float results are exact for small integers. This also
285281
# holds for the same integers scaled by powers of two.
@@ -311,10 +307,6 @@ def test_float_modulus_roundoff(self):
311307
# gh-6127
312308
dt = np.typecodes['Float']
313309
for op in [floordiv_and_mod, divmod]:
314-
315-
if op == divmod:
316-
pytest.xfail(reason="__divmod__ not implemented")
317-
318310
for dt1, dt2 in itertools.product(dt, dt):
319311
for sg1, sg2 in itertools.product((+1, -1), (+1, -1)):
320312
fmt = 'op: %s, dt1: %s, dt2: %s, sg1: %s, sg2: %s'
@@ -329,41 +321,42 @@ def test_float_modulus_roundoff(self):
329321
else:
330322
assert_(b > rem >= 0, msg)
331323

332-
@pytest.mark.skip(reason='float16 on cpu is incomplete in pytorch')
333-
def test_float_modulus_corner_cases(self):
334-
# Check remainder magnitude.
335-
for dt in np.typecodes['Float']:
336-
b = np.array(1.0, dtype=dt)
337-
a = np.nextafter(np.array(0.0, dtype=dt), -b)
338-
rem = operator.mod(a, b)
339-
assert_(rem <= b, 'dt: %s' % dt)
340-
rem = operator.mod(-a, -b)
341-
assert_(rem >= -b, 'dt: %s' % dt)
324+
@pytest.mark.parametrize('dt', np.typecodes['Float'])
325+
def test_float_modulus_corner_cases(self, dt):
326+
if dt == 'e':
327+
pytest.xfail(reason="RuntimeError: 'nextafter_cpu' not implemented for 'Half'")
328+
329+
b = np.array(1.0, dtype=dt)
330+
a = np.nextafter(np.array(0.0, dtype=dt), -b)
331+
rem = operator.mod(a, b)
332+
assert_(rem <= b, 'dt: %s' % dt)
333+
rem = operator.mod(-a, -b)
334+
assert_(rem >= -b, 'dt: %s' % dt)
342335

343336
# Check nans, inf
344-
with suppress_warnings() as sup:
345-
sup.filter(RuntimeWarning, "invalid value encountered in remainder")
346-
sup.filter(RuntimeWarning, "divide by zero encountered in remainder")
347-
sup.filter(RuntimeWarning, "divide by zero encountered in floor_divide")
348-
sup.filter(RuntimeWarning, "divide by zero encountered in divmod")
349-
sup.filter(RuntimeWarning, "invalid value encountered in divmod")
350-
for dt in np.typecodes['Float']:
351-
fone = np.array(1.0, dtype=dt)
352-
fzer = np.array(0.0, dtype=dt)
353-
finf = np.array(np.inf, dtype=dt)
354-
fnan = np.array(np.nan, dtype=dt)
355-
rem = operator.mod(fone, fzer)
356-
assert_(np.isnan(rem), 'dt: %s' % dt)
357-
# MSVC 2008 returns NaN here, so disable the check.
358-
#rem = operator.mod(fone, finf)
359-
#assert_(rem == fone, 'dt: %s' % dt)
360-
rem = operator.mod(fone, fnan)
361-
assert_(np.isnan(rem), 'dt: %s' % dt)
362-
rem = operator.mod(finf, fone)
363-
assert_(np.isnan(rem), 'dt: %s' % dt)
364-
for op in [floordiv_and_mod, divmod]:
365-
div, mod = op(fone, fzer)
366-
assert_(np.isinf(div)) and assert_(np.isnan(mod))
337+
# with suppress_warnings() as sup:
338+
# sup.filter(RuntimeWarning, "invalid value encountered in remainder")
339+
# sup.filter(RuntimeWarning, "divide by zero encountered in remainder")
340+
# sup.filter(RuntimeWarning, "divide by zero encountered in floor_divide")
341+
# sup.filter(RuntimeWarning, "divide by zero encountered in divmod")
342+
# sup.filter(RuntimeWarning, "invalid value encountered in divmod")
343+
for dt in np.typecodes['Float']:
344+
fone = np.array(1.0, dtype=dt)
345+
fzer = np.array(0.0, dtype=dt)
346+
finf = np.array(np.inf, dtype=dt)
347+
fnan = np.array(np.nan, dtype=dt)
348+
rem = operator.mod(fone, fzer)
349+
assert_(np.isnan(rem), 'dt: %s' % dt)
350+
# MSVC 2008 returns NaN here, so disable the check.
351+
#rem = operator.mod(fone, finf)
352+
#assert_(rem == fone, 'dt: %s' % dt)
353+
rem = operator.mod(fone, fnan)
354+
assert_(np.isnan(rem), 'dt: %s' % dt)
355+
rem = operator.mod(finf, fone)
356+
assert_(np.isnan(rem), 'dt: %s' % dt)
357+
for op in [floordiv_and_mod, divmod]:
358+
div, mod = op(fone, fzer)
359+
assert_(np.isinf(div)) and assert_(np.isnan(mod))
367360

368361

369362
class TestComplexDivision:

0 commit comments

Comments
 (0)