Skip to content

Commit 1acb5aa

Browse files
committed
MAINT: finish up numeric binops/dunders
1 parent d50b626 commit 1acb5aa

File tree

2 files changed

+66
-30
lines changed

2 files changed

+66
-30
lines changed

torch_np/_ndarray.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def __len__(self):
182182

183183
### arithmetic ###
184184

185-
# add
185+
# add, self + other
186186
def __add__(self, other):
187187
return _ufunc_impl.add(self, asarray(other))
188188

@@ -193,7 +193,7 @@ def __iadd__(self, other):
193193
return _ufunc_impl.add(self, asarray(other), out=self)
194194

195195

196-
# sub
196+
# sub, self - other
197197
def __sub__(self, other):
198198
return _ufunc_impl.subtract(self, asarray(other))
199199

@@ -204,7 +204,7 @@ def __isub__(self, other):
204204
return _ufunc_impl.subtract(self, asarray(other), out=self)
205205

206206

207-
# mul
207+
# mul, self * other
208208
def __mul__(self, other):
209209
return _ufunc_impl.multiply(self, asarray(other))
210210

@@ -215,30 +215,52 @@ def __imul__(self, other):
215215
return _ufunc_impl.multiply(self, asarray(other), out=self)
216216

217217

218+
# div, self / other
219+
def __truediv__(self, other):
220+
return _ufunc_impl.divide(self, asarray(other))
221+
222+
def __rtruediv__(self, other):
223+
return _ufunc_impl.divide(self, asarray(other))
224+
225+
def __itruediv__(self, other):
226+
return _ufunc_impl.divide(self, asarray(other), out=self)
227+
218228

229+
# floordiv, self // other
219230
def __floordiv__(self, other):
220-
other_tensor = asarray(other).get()
221-
return asarray(self._tensor.__floordiv__(other_tensor))
231+
return _ufunc_impl.floor_divide(self, asarray(other))
232+
233+
def __rfloordiv__(self, other):
234+
return _ufunc_impl.floor_divide(self, asarray(other))
222235

223236
def __ifloordiv__(self, other):
224-
other_tensor = asarray(other).get()
225-
return asarray(self._tensor.__ifloordiv__(other_tensor))
237+
return _ufunc_impl.floor_divide(self, asarray(other), out=self)
226238

227-
def __truediv__(self, other):
228-
other_tensor = asarray(other).get()
229-
return asarray(self._tensor.__truediv__(other_tensor))
230239

231-
def __itruediv__(self, other):
232-
other_tensor = asarray(other).get()
233-
return asarray(self._tensor.__itruediv__(other_tensor))
240+
# power, self**exponent
241+
def __pow__(self, exponent):
242+
return _ufunc_impl.float_power(self, asarray(exponent))
243+
244+
def __rpow__(self, exponent):
245+
return _ufunc_impl.float_power(self, asarray(exponent))
246+
247+
def __ipow__(self, exponent):
248+
return _ufunc_impl.float_power(self, asarray(exponent), out=self)
234249

250+
# remainder, self % other
235251
def __mod__(self, other):
236-
other_tensor = asarray(other).get()
237-
return asarray(self._tensor.__mod__(other_tensor))
252+
return _ufunc_impl.remainder(self, asarray(other))
253+
254+
def __rmod__(self, other):
255+
return _ufunc_impl.remainder(self, asarray(other))
238256

239257
def __imod__(self, other):
240-
other_tensor = asarray(other).get()
241-
return asarray(self._tensor.__imod__(other_tensor))
258+
return _ufunc_impl.remainder(self, asarray(other), out=self)
259+
260+
261+
262+
263+
# FIXME ops and binops below
242264

243265
def __or__(self, other):
244266
other_tensor = asarray(other).get()
@@ -260,9 +282,7 @@ def __neg__(self):
260282
except RuntimeError as e:
261283
raise TypeError(e.args)
262284

263-
def __pow__(self, exponent):
264-
exponent_tensor = asarray(exponent).get()
265-
return asarray(self._tensor.__pow__(exponent_tensor))
285+
266286

267287
### methods to match namespace functions
268288

torch_np/tests/test_ufuncs_basic.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,28 @@ def test_x_and_out_broadcast(self, ufunc):
9393
(np.add, operator.__add__, operator.__iadd__),
9494
(np.subtract, operator.__sub__, operator.__isub__),
9595
(np.multiply, operator.__mul__, operator.__imul__),
96-
# divide
97-
# true_divide?
96+
(np.divide, operator.__truediv__, operator.__itruediv__),
97+
(np.floor_divide, operator.__floordiv__, operator.__ifloordiv__),
98+
(np.float_power, operator.__pow__, operator.__ipow__),
99+
## (np.remainder, operator.__mod__, operator.__imod__), # does not handle complex
100+
101+
98102
# remainder vs fmod?
99103
# pow vs power vs float_power
100104
]
101105

102106
ufuncs_with_dunders = [ufunc for ufunc, _, _ in ufunc_op_iop_numeric]
107+
numeric_binary_ufuncs = [np.float_power, np.power,]
103108

104-
numeric_binary_ufuncs = [np.float_power, np.power,
105109
# these are not implemented for complex inputs
106-
# np.hypot, np.arctan2, np.copysign,
107-
# np.floor_divide, np.fmax, np.fmin, np.fmod,
108-
# np.heaviside, np.logaddexp, np.logaddexp2, np.maximum, np.minimum,
110+
no_complex = [np.floor_divide, np.hypot, np.arctan2, np.copysign, np.fmax,
111+
np.fmin, np.fmod, np.heaviside, np.logaddexp, np.logaddexp2,
112+
np.maximum, np.minimum,
109113
]
110114

111115
parametrize_binary_ufuncs = pytest.mark.parametrize(
112-
'ufunc', ufuncs_with_dunders + numeric_binary_ufuncs)
116+
'ufunc', ufuncs_with_dunders + numeric_binary_ufuncs + no_complex)
117+
113118

114119

115120
# TODO: these snowflakes need special handling
@@ -159,6 +164,9 @@ def test_xy_and_out_casting(self, ufunc, casting, out_dtype):
159164
x, y = self.get_xy(ufunc)
160165
out = np.empty_like(x, dtype=out_dtype)
161166

167+
if ufunc in no_complex and np.issubdtype(out_dtype, np.complexfloating):
168+
pytest.skip(f'{ufunc} does not accept complex.')
169+
162170
can_cast_x = np.can_cast(x, out_dtype, casting=casting)
163171
can_cast_y = np.can_cast(y, out_dtype, casting=casting)
164172

@@ -208,12 +216,13 @@ def test_basic(self, ufunc, op, iop):
208216
assert_equal(op(a.tolist(), a), ufunc(a, a.tolist()))
209217

210218
# __iadd__
211-
a0 = np.array([1, 2, 3])
219+
a0 = np.array([2, 4, 6])
212220
a = a0.copy()
221+
213222
iop(a, 2) # modifies a in-place
214223
assert_equal(a, op(a0, 2))
215224

216-
a0 = np.array([1, 2, 3])
225+
a0 = np.array([2, 4, 6])
217226
a = a0.copy()
218227
iop(a, a)
219228
assert_equal(a, op(a0, a0))
@@ -225,6 +234,9 @@ def test_other_scalar(self, ufunc, op, iop, other_dtype):
225234
a = np.array([1, 2, 3])
226235
b = other_dtype(3)
227236

237+
if ufunc in no_complex and issubclass(other_dtype, np.complexfloating):
238+
pytest.skip(f'{ufunc} does not accept complex.')
239+
228240
# __op__
229241
result = op(a, b)
230242
assert_equal(result, ufunc(a, b))
@@ -253,9 +265,13 @@ def test_other_scalar(self, ufunc, op, iop, other_dtype):
253265
@pytest.mark.parametrize("other_dtype", dtypes_numeric)
254266
def test_other_array(self, ufunc, op, iop, other_dtype):
255267
"""Test op/iop/rop when the other argument is an array of a different dtype."""
256-
# __op__
257268
a = np.array([1, 2, 3])
258269
b = np.array([5, 6, 7], dtype=other_dtype)
270+
271+
if ufunc in no_complex and issubclass(other_dtype, np.complexfloating):
272+
pytest.skip(f'{ufunc} does not accept complex.')
273+
274+
# __op__
259275
result = op(a, b)
260276
assert_equal(result, ufunc(a, b))
261277
assert result.dtype == np.result_type(a, b)

0 commit comments

Comments
 (0)