Skip to content

Commit 39248a1

Browse files
committed
MAINT: delegate ndarray unary dunders to their ufuncs
1 parent 6815005 commit 39248a1

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

torch_np/_ndarray.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def __rpow__(self, exponent):
247247
def __ipow__(self, exponent):
248248
return _ufunc_impl.float_power(self, asarray(exponent), out=self)
249249

250+
250251
# remainder, self % other
251252
def __mod__(self, other):
252253
return _ufunc_impl.remainder(self, asarray(other))
@@ -270,17 +271,19 @@ def __ior__(self, other):
270271
other_tensor = asarray(other).get()
271272
return asarray(self._tensor.__ior__(other_tensor))
272273

274+
275+
# unary ops
273276
def __invert__(self):
274-
return asarray(self._tensor.__invert__())
277+
return _ufunc_impl.invert(self)
275278

276279
def __abs__(self):
277-
return asarray(self._tensor.__abs__())
280+
return _ufunc_impl.absolute(self)
281+
282+
def __pos__(self):
283+
return _ufunc_impl.positive(self)
278284

279285
def __neg__(self):
280-
try:
281-
return asarray(self._tensor.__neg__())
282-
except RuntimeError as e:
283-
raise TypeError(e.args)
286+
return _ufunc_impl.negative(self)
284287

285288

286289

torch_np/_unary_ufuncs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from . import _ufunc_impl
1111

12-
__all__ = ['abs', 'absolute', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctanh', 'asarray', 'cbrt', 'ceil', 'conj', 'conjugate', 'cos', 'cosh', 'deg2rad', 'degrees', 'exp', 'exp2', 'expm1', 'fabs', 'floor', 'isfinite', 'isinf', 'isnan', 'log', 'log10', 'log1p', 'log2', 'logical_not', 'negative', 'positive', 'rad2deg', 'radians', 'reciprocal', 'rint', 'sign', 'signbit', 'sin', 'sinh', 'sqrt', 'square', 'tan', 'tanh', 'trunc']
12+
__all__ = ['abs', 'absolute', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctanh', 'asarray', 'cbrt', 'ceil', 'conj', 'conjugate', 'cos', 'cosh', 'deg2rad', 'degrees', 'exp', 'exp2', 'expm1', 'fabs', 'floor', 'isfinite', 'isinf', 'isnan', 'log', 'log10', 'log1p', 'log2', 'logical_not', 'negative', 'positive', 'rad2deg', 'radians', 'reciprocal', 'rint', 'sign', 'signbit', 'sin', 'sinh', 'sqrt', 'square', 'tan', 'tanh', 'trunc', 'invert']
1313

1414

1515

torch_np/tests/numpy_tests/core/test_scalarmath.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,8 @@ def __array__(self):
619619
class TestNegative:
620620
def test_exceptions(self):
621621
a = np.ones((), dtype=np.bool_)[()]
622-
assert_raises(TypeError, operator.neg, a)
622+
# XXX: TypeError from numpy, RuntimeError from torch
623+
assert_raises((TypeError, RuntimeError), operator.neg, a)
623624

624625
def test_result(self):
625626
types = np.typecodes['AllInteger'] + np.typecodes['AllFloat']

0 commit comments

Comments
 (0)