Skip to content

Commit 71bb501

Browse files
committed
MAINT: finish up numeric binops/dunders
1 parent d50b626 commit 71bb501

File tree

2 files changed

+56
-26
lines changed

2 files changed

+56
-26
lines changed

torch_np/_ndarray.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def __len__(self):
182182

183183
### arithmetic ###
184184

185-
# add
185+
# add, self + other
186186
def __add__(self, other):
187187
return _ufunc_impl.add(self, asarray(other))
188188

@@ -193,7 +193,7 @@ def __iadd__(self, other):
193193
return _ufunc_impl.add(self, asarray(other), out=self)
194194

195195

196-
# sub
196+
# sub, self - other
197197
def __sub__(self, other):
198198
return _ufunc_impl.subtract(self, asarray(other))
199199

@@ -204,7 +204,7 @@ def __isub__(self, other):
204204
return _ufunc_impl.subtract(self, asarray(other), out=self)
205205

206206

207-
# mul
207+
# mul, self * other
208208
def __mul__(self, other):
209209
return _ufunc_impl.multiply(self, asarray(other))
210210

@@ -215,23 +215,40 @@ def __imul__(self, other):
215215
return _ufunc_impl.multiply(self, asarray(other), out=self)
216216

217217

218+
# div, self / other
219+
def __truediv__(self, other):
220+
return _ufunc_impl.divide(self, asarray(other))
221+
222+
def __rtruediv__(self, other):
223+
return _ufunc_impl.divide(self, asarray(other))
224+
225+
def __itruediv__(self, other):
226+
return _ufunc_impl.divide(self, asarray(other), out=self)
227+
218228

229+
# floordiv, self // other
219230
def __floordiv__(self, other):
220-
other_tensor = asarray(other).get()
221-
return asarray(self._tensor.__floordiv__(other_tensor))
231+
return _ufunc_impl.floor_divide(self, asarray(other))
232+
233+
def __rfloordiv__(self, other):
234+
return _ufunc_impl.floor_divide(self, asarray(other))
222235

223236
def __ifloordiv__(self, other):
224-
other_tensor = asarray(other).get()
225-
return asarray(self._tensor.__ifloordiv__(other_tensor))
237+
return _ufunc_impl.floor_divide(self, asarray(other), out=self)
226238

227-
def __truediv__(self, other):
228-
other_tensor = asarray(other).get()
229-
return asarray(self._tensor.__truediv__(other_tensor))
230239

231-
def __itruediv__(self, other):
232-
other_tensor = asarray(other).get()
233-
return asarray(self._tensor.__itruediv__(other_tensor))
240+
# power, self**exponent
241+
def __pow__(self, exponent):
242+
return _ufunc_impl.float_power(self, asarray(exponent))
243+
244+
def __rpow__(self, exponent):
245+
return _ufunc_impl.float_power(self, asarray(exponent))
234246

247+
def __ipow__(self, exponent):
248+
return _ufunc_impl.float_power(self, asarray(exponent), out=self)
249+
250+
251+
# FIXME ops and binops below
235252
def __mod__(self, other):
236253
other_tensor = asarray(other).get()
237254
return asarray(self._tensor.__mod__(other_tensor))
@@ -260,9 +277,7 @@ def __neg__(self):
260277
except RuntimeError as e:
261278
raise TypeError(e.args)
262279

263-
def __pow__(self, exponent):
264-
exponent_tensor = asarray(exponent).get()
265-
return asarray(self._tensor.__pow__(exponent_tensor))
280+
266281

267282
### methods to match namespace functions
268283

torch_np/tests/test_ufuncs_basic.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,27 @@ def test_x_and_out_broadcast(self, ufunc):
9393
(np.add, operator.__add__, operator.__iadd__),
9494
(np.subtract, operator.__sub__, operator.__isub__),
9595
(np.multiply, operator.__mul__, operator.__imul__),
96-
# divide
97-
# true_divide?
96+
(np.divide, operator.__truediv__, operator.__itruediv__),
97+
(np.floor_divide, operator.__floordiv__, operator.__ifloordiv__),
98+
(np.float_power, operator.__pow__, operator.__ipow__),
99+
100+
98101
# remainder vs fmod?
99102
# pow vs power vs float_power
100103
]
101104

102105
ufuncs_with_dunders = [ufunc for ufunc, _, _ in ufunc_op_iop_numeric]
106+
numeric_binary_ufuncs = [np.float_power, np.power,]
103107

104-
numeric_binary_ufuncs = [np.float_power, np.power,
105108
# these are not implemented for complex inputs
106-
# np.hypot, np.arctan2, np.copysign,
107-
# np.floor_divide, np.fmax, np.fmin, np.fmod,
108-
# np.heaviside, np.logaddexp, np.logaddexp2, np.maximum, np.minimum,
109+
no_complex = [np.floor_divide, np.hypot, np.arctan2, np.copysign, np.fmax,
110+
np.fmin, np.fmod, np.heaviside, np.logaddexp, np.logaddexp2,
111+
np.maximum, np.minimum,
109112
]
110113

111114
parametrize_binary_ufuncs = pytest.mark.parametrize(
112-
'ufunc', ufuncs_with_dunders + numeric_binary_ufuncs)
115+
'ufunc', ufuncs_with_dunders + numeric_binary_ufuncs + no_complex)
116+
113117

114118

115119
# TODO: these snowflakes need special handling
@@ -159,6 +163,9 @@ def test_xy_and_out_casting(self, ufunc, casting, out_dtype):
159163
x, y = self.get_xy(ufunc)
160164
out = np.empty_like(x, dtype=out_dtype)
161165

166+
if ufunc in no_complex and np.issubdtype(out_dtype, np.complexfloating):
167+
pytest.skip(f'{ufunc} does not accept complex.')
168+
162169
can_cast_x = np.can_cast(x, out_dtype, casting=casting)
163170
can_cast_y = np.can_cast(y, out_dtype, casting=casting)
164171

@@ -208,12 +215,13 @@ def test_basic(self, ufunc, op, iop):
208215
assert_equal(op(a.tolist(), a), ufunc(a, a.tolist()))
209216

210217
# __iadd__
211-
a0 = np.array([1, 2, 3])
218+
a0 = np.array([2, 4, 6])
212219
a = a0.copy()
220+
213221
iop(a, 2) # modifies a in-place
214222
assert_equal(a, op(a0, 2))
215223

216-
a0 = np.array([1, 2, 3])
224+
a0 = np.array([2, 4, 6])
217225
a = a0.copy()
218226
iop(a, a)
219227
assert_equal(a, op(a0, a0))
@@ -225,6 +233,9 @@ def test_other_scalar(self, ufunc, op, iop, other_dtype):
225233
a = np.array([1, 2, 3])
226234
b = other_dtype(3)
227235

236+
if ufunc in no_complex and issubclass(other_dtype, np.complexfloating):
237+
pytest.skip(f'{ufunc} does not accept complex.')
238+
228239
# __op__
229240
result = op(a, b)
230241
assert_equal(result, ufunc(a, b))
@@ -253,9 +264,13 @@ def test_other_scalar(self, ufunc, op, iop, other_dtype):
253264
@pytest.mark.parametrize("other_dtype", dtypes_numeric)
254265
def test_other_array(self, ufunc, op, iop, other_dtype):
255266
"""Test op/iop/rop when the other argument is an array of a different dtype."""
256-
# __op__
257267
a = np.array([1, 2, 3])
258268
b = np.array([5, 6, 7], dtype=other_dtype)
269+
270+
if ufunc in no_complex and issubclass(other_dtype, np.complexfloating):
271+
pytest.skip(f'{ufunc} does not accept complex.')
272+
273+
# __op__
259274
result = op(a, b)
260275
assert_equal(result, ufunc(a, b))
261276
assert result.dtype == np.result_type(a, b)

0 commit comments

Comments
 (0)