Skip to content

Commit 3a37806

Browse files
authored
Merge pull request #44 from Quansight-Labs/op_rop
Fix __rop__ dunders
2 parents ff94fdb + 7ac0caf commit 3a37806

File tree

4 files changed

+25
-6
lines changed

4 files changed

+25
-6
lines changed

torch_np/_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def to_tensors_or_none(*inputs):
9494

9595

9696
def _outer(x, y):
97+
from ._ndarray import asarray
98+
9799
x_tensor, y_tensor = to_tensors(x, y)
98100
result = torch.outer(x_tensor, y_tensor)
99101
return asarray(result)

torch_np/_ndarray.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,11 @@ def __iadd__(self, other):
182182
return _binary_ufuncs.add(self, other, out=self)
183183

184184
# sub, self - other
185-
__sub__ = __rsub__ = _binary_ufuncs.subtract
185+
__sub__ = _binary_ufuncs.subtract
186+
187+
# XXX: generate a function just for this? AND other non-commutative ops.
188+
def __rsub__(self, other):
189+
return _binary_ufuncs.subtract(other, self)
186190

187191
def __isub__(self, other):
188192
return _binary_ufuncs.subtract(self, other, out=self)
@@ -194,20 +198,29 @@ def __imul__(self, other):
194198
return _binary_ufuncs.multiply(self, other, out=self)
195199

196200
# div, self / other
197-
__truediv__ = __rtruediv__ = _binary_ufuncs.divide
201+
__truediv__ = _binary_ufuncs.divide
202+
203+
def __rtruediv__(self, other):
204+
return _binary_ufuncs.divide(other, self)
198205

199206
def __itruediv__(self, other):
200207
return _binary_ufuncs.divide(self, other, out=self)
201208

202209
# floordiv, self // other
203-
__floordiv__ = __rfloordiv__ = _binary_ufuncs.floor_divide
210+
__floordiv__ = _binary_ufuncs.floor_divide
211+
212+
def __rfloordiv__(self, other):
213+
return _binary_ufuncs.floor_divide(other, self)
204214

205215
def __ifloordiv__(self, other):
206216
return _binary_ufuncs.floor_divide(self, other, out=self)
207217

208218
# power, self**exponent
209219
__pow__ = __rpow__ = _binary_ufuncs.float_power
210220

221+
def __rpow__(self, exponent):
222+
return _binary_ufuncs.float_power(exponent, self)
223+
211224
def __ipow__(self, exponent):
212225
return _binary_ufuncs.float_power(self, exponent, out=self)
213226

torch_np/_wrapper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from . import _dtypes, _helpers, _decorators # isort: skip # XXX
2323

24+
2425
# Things to decide on (punt for now)
2526
#
2627
# 1. Q: What are the return types of wrapper functions: plain torch.Tensors or
@@ -273,12 +274,15 @@ def ones_like(a, dtype=None, order="K", subok=False, shape=None):
273274
return result
274275

275276

276-
# XXX: dtype=float
277277
@_decorators.dtype_to_torch
278-
def zeros(shape, dtype=float, order="C", *, like=None):
278+
def zeros(shape, dtype=None, order="C", *, like=None):
279279
_util.subok_not_ok(like)
280280
if order != "C":
281281
raise NotImplementedError
282+
if dtype is None:
283+
from ._detail._scalar_types import default_float_type
284+
285+
dtype = default_float_type.torch_dtype
282286
return asarray(torch.zeros(shape, dtype=dtype))
283287

284288

torch_np/tests/test_ufuncs_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def test_basic(self, ufunc, op, iop):
233233

234234
# __radd__
235235
a = np.array([1, 2, 3])
236-
assert_equal(op(1, a), ufunc(a, 1))
236+
assert_equal(op(1, a), ufunc(1, a))
237237
assert_equal(op(a.tolist(), a), ufunc(a, a.tolist()))
238238

239239
# __iadd__

0 commit comments

Comments
 (0)