diff --git a/torch_np/_helpers.py b/torch_np/_helpers.py
index f20e478f..e7d863c8 100644
--- a/torch_np/_helpers.py
+++ b/torch_np/_helpers.py
@@ -94,6 +94,8 @@ def to_tensors_or_none(*inputs):
 
 
 def _outer(x, y):
+    from ._ndarray import asarray
+
     x_tensor, y_tensor = to_tensors(x, y)
     result = torch.outer(x_tensor, y_tensor)
     return asarray(result)
diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py
index c5ed3b4a..746fe4c5 100644
--- a/torch_np/_ndarray.py
+++ b/torch_np/_ndarray.py
@@ -182,7 +182,11 @@ def __iadd__(self, other):
         return _binary_ufuncs.add(self, other, out=self)
 
     # sub, self - other
-    __sub__ = __rsub__ = _binary_ufuncs.subtract
+    __sub__ = _binary_ufuncs.subtract
+
+    # XXX: generate a function just for this? AND other non-commutative ops.
+    def __rsub__(self, other):
+        return _binary_ufuncs.subtract(other, self)
 
     def __isub__(self, other):
         return _binary_ufuncs.subtract(self, other, out=self)
@@ -194,13 +198,19 @@ def __imul__(self, other):
         return _binary_ufuncs.multiply(self, other, out=self)
 
     # div, self / other
-    __truediv__ = __rtruediv__ = _binary_ufuncs.divide
+    __truediv__ = _binary_ufuncs.divide
+
+    def __rtruediv__(self, other):
+        return _binary_ufuncs.divide(other, self)
 
     def __itruediv__(self, other):
         return _binary_ufuncs.divide(self, other, out=self)
 
     # floordiv, self // other
-    __floordiv__ = __rfloordiv__ = _binary_ufuncs.floor_divide
+    __floordiv__ = _binary_ufuncs.floor_divide
+
+    def __rfloordiv__(self, other):
+        return _binary_ufuncs.floor_divide(other, self)
 
     def __ifloordiv__(self, other):
         return _binary_ufuncs.floor_divide(self, other, out=self)
@@ -208,6 +218,9 @@ def __ifloordiv__(self, other):
 
     # power, self**exponent
     __pow__ = __rpow__ = _binary_ufuncs.float_power
 
+    def __rpow__(self, exponent):
+        return _binary_ufuncs.float_power(exponent, self)
+
     def __ipow__(self, exponent):
         return _binary_ufuncs.float_power(self, exponent, out=self)
diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py
index 64cd83a0..755efd01 100644
--- a/torch_np/_wrapper.py
+++ b/torch_np/_wrapper.py
@@ -21,6 +21,7 @@ from . import _dtypes, _helpers, _decorators  # isort: skip
 
 
 # XXX
+
 # Things to decide on (punt for now)
 #
 # 1. Q: What are the return types of wrapper functions: plain torch.Tensors or
@@ -273,12 +274,15 @@ def ones_like(a, dtype=None, order="K", subok=False, shape=None):
     return result
 
 
-# XXX: dtype=float
 @_decorators.dtype_to_torch
-def zeros(shape, dtype=float, order="C", *, like=None):
+def zeros(shape, dtype=None, order="C", *, like=None):
     _util.subok_not_ok(like)
     if order != "C":
         raise NotImplementedError
+    if dtype is None:
+        from ._detail._scalar_types import default_float_type
+
+        dtype = default_float_type.torch_dtype
     return asarray(torch.zeros(shape, dtype=dtype))
 
 
diff --git a/torch_np/tests/test_ufuncs_basic.py b/torch_np/tests/test_ufuncs_basic.py
index 7bbb7294..181f1d04 100644
--- a/torch_np/tests/test_ufuncs_basic.py
+++ b/torch_np/tests/test_ufuncs_basic.py
@@ -233,7 +233,7 @@ def test_basic(self, ufunc, op, iop):
 
         # __radd__
         a = np.array([1, 2, 3])
-        assert_equal(op(1, a), ufunc(a, 1))
+        assert_equal(op(1, a), ufunc(1, a))
         assert_equal(op(a.tolist(), a), ufunc(a, a.tolist()))
 
         # __iadd__
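
Note, not part of the patch: a minimal sketch of the reflected-operator behaviour the __rsub__ / __rtruediv__ / __rfloordiv__ / __rpow__ changes are meant to produce, assuming the package is importable as torch_np:

    import torch_np as np

    a = np.array([1, 2, 3])

    # Reflected dunders must swap the operands before dispatching to the ufunc;
    # a bare `__rsub__ = subtract` assignment would compute subtract(a, 1) instead.
    print(1 - a)   # subtract(1, a)      -> [ 0, -1, -2]
    print(6 // a)  # floor_divide(6, a)  -> [6, 3, 2]
    print(2 ** a)  # float_power(2, a)   -> [2., 4., 8.]

The test update in test_ufuncs_basic.py checks the same property: op(1, a) must equal ufunc(1, a), i.e. the scalar stays on the left.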