diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 15659e65..e833c4e7 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -57,6 +57,84 @@ def __getitem__(self, key): raise KeyError(f"No flag key '{key}'") +def create_method(fn, name=None): + name = name or fn.__name__ + + def f(*args, **kwargs): + return fn(*args, **kwargs) + + f.__name__ = name + f.__qualname__ = f"ndarray.{name}" + return f + + +# Map ndarray.name_method -> np.name_func +# If name_func == None, it means that name_method == name_func +methods = { + "clip": None, + "flatten": "_flatten", + "nonzero": None, + "repeat": None, + "round": None, + "squeeze": None, + "swapaxes": None, + "ravel": None, + # linalg + "diagonal": None, + "dot": None, + "trace": None, + # sorting + "argsort": None, + "searchsorted": None, + # reductions + "argmax": None, + "argmin": None, + "any": None, + "all": None, + "max": None, + "min": None, + "ptp": None, + "sum": None, + "prod": None, + "mean": None, + "var": None, + "std": None, + # scans + "cumsum": None, + "cumprod": None, + # advanced indexing + "take": None, +} + +dunder = { + "abs": "absolute", + "invert": None, + "pos": "positive", + "neg": "negative", + "gt": "greater", + "lt": "less", + "ge": "greater_equal", + "le": "less_equal", +} + +# dunder methods with right-looking and in-place variants +ri_dunder = { + "add": None, + "sub": "subtract", + "mul": "multiply", + "truediv": "divide", + "floordiv": "floor_divide", + "pow": "float_power", + "mod": "remainder", + "and": "bitwise_and", + "or": "bitwise_or", + "xor": "bitwise_xor", + "lshift": "left_shift", + "rshift": "right_shift", + "matmul": None, +} + + ##################### ndarray class ########################### @@ -72,6 +150,37 @@ def __init__(self, t=None): "either array(...) or zeros/empty(...)" ) + # Register NumPy functions as methods + for method, name in methods.items(): + fn = getattr(_funcs, name or method) + vars()[method] = create_method(fn, method) + + # Regular methods but coming from ufuncs + conj = create_method(_ufuncs.conjugate, "conj") + conjugate = create_method(_ufuncs.conjugate) + + for method, name in dunder.items(): + fn = getattr(_ufuncs, name or method) + method = f"__{method}__" + vars()[method] = create_method(fn, method) + + for method, name in ri_dunder.items(): + fn = getattr(_ufuncs, name or method) + plain = f"__{method}__" + vars()[plain] = create_method(fn, plain) + rvar = f"__r{method}__" + vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar) + ivar = f"__i{method}__" + vars()[ivar] = create_method( + lambda self, other, fn=fn: fn(self, other, out=self), ivar + ) + + # There's no __idivmod__ + __divmod__ = create_method(_ufuncs.divmod, "__divmod__") + __rdivmod__ = create_method( + lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__" + ) + @property def shape(self): return tuple(self.tensor.shape) @@ -100,18 +209,10 @@ def itemsize(self): @property def flags(self): # Note contiguous in torch is assumed C-style - - # check if F contiguous - from itertools import accumulate - - f_strides = tuple(accumulate(list(self.tensor.shape), func=lambda x, y: x * y)) - f_strides = (1,) + f_strides[:-1] - is_f_contiguous = f_strides == self.tensor.stride() - return Flags( { "C_CONTIGUOUS": self.tensor.is_contiguous(), - "F_CONTIGUOUS": is_f_contiguous, + "F_CONTIGUOUS": self.T.tensor.is_contiguous(), "OWNDATA": self.tensor._base is None, "WRITEABLE": True, # pytorch does not have readonly tensors } @@ -145,14 +246,11 @@ def imag(self): def imag(self, value): self.tensor.imag = asarray(value).tensor - round = _funcs.round - # ctors def astype(self, dtype): - newt = ndarray() torch_dtype = _dtypes.dtype(dtype).torch_dtype - newt.tensor = self.tensor.to(torch_dtype) - return newt + t = self.tensor.to(torch_dtype) + return ndarray(t) def copy(self, order="C"): if order != "C": @@ -182,7 +280,7 @@ def __str__(self): .replace("dtype=torch.", "dtype=") ) - __repr__ = __str__ + __repr__ = create_method(__str__) ### comparisons ### def __eq__(self, other): @@ -201,11 +299,6 @@ def __ne__(self, other): falsy = torch.full(self.shape, fill_value=True, dtype=bool) return asarray(falsy) - __gt__ = _ufuncs.greater - __lt__ = _ufuncs.less - __ge__ = _ufuncs.greater_equal - __le__ = _ufuncs.less_equal - def __bool__(self): try: return bool(self.tensor) @@ -247,117 +340,7 @@ def is_integer(self): def __len__(self): return self.tensor.shape[0] - ### arithmetic ### - - # add, self + other - __add__ = __radd__ = _ufuncs.add - - def __iadd__(self, other): - return _ufuncs.add(self, other, out=self) - - # sub, self - other - __sub__ = _ufuncs.subtract - - # XXX: generate a function just for this? AND other non-commutative ops. - def __rsub__(self, other): - return _ufuncs.subtract(other, self) - - def __isub__(self, other): - return _ufuncs.subtract(self, other, out=self) - - # mul, self * other - __mul__ = __rmul__ = _ufuncs.multiply - - def __imul__(self, other): - return _ufuncs.multiply(self, other, out=self) - - # div, self / other - __truediv__ = _ufuncs.divide - - def __rtruediv__(self, other): - return _ufuncs.divide(other, self) - - def __itruediv__(self, other): - return _ufuncs.divide(self, other, out=self) - - # floordiv, self // other - __floordiv__ = _ufuncs.floor_divide - - def __rfloordiv__(self, other): - return _ufuncs.floor_divide(other, self) - - def __ifloordiv__(self, other): - return _ufuncs.floor_divide(self, other, out=self) - - __divmod__ = _ufuncs.divmod - - # power, self**exponent - __pow__ = __rpow__ = _ufuncs.float_power - - def __rpow__(self, exponent): - return _ufuncs.float_power(exponent, self) - - def __ipow__(self, exponent): - return _ufuncs.float_power(self, exponent, out=self) - - # remainder, self % other - __mod__ = __rmod__ = _ufuncs.remainder - - def __imod__(self, other): - return _ufuncs.remainder(self, other, out=self) - - # bitwise ops - # and, self & other - __and__ = __rand__ = _ufuncs.bitwise_and - - def __iand__(self, other): - return _ufuncs.bitwise_and(self, other, out=self) - - # or, self | other - __or__ = __ror__ = _ufuncs.bitwise_or - - def __ior__(self, other): - return _ufuncs.bitwise_or(self, other, out=self) - - # xor, self ^ other - __xor__ = __rxor__ = _ufuncs.bitwise_xor - - def __ixor__(self, other): - return _ufuncs.bitwise_xor(self, other, out=self) - - # bit shifts - __lshift__ = __rlshift__ = _ufuncs.left_shift - - def __ilshift__(self, other): - return _ufuncs.left_shift(self, other, out=self) - - __rshift__ = __rrshift__ = _ufuncs.right_shift - - def __irshift__(self, other): - return _ufuncs.right_shift(self, other, out=self) - - __matmul__ = _ufuncs.matmul - - def __rmatmul__(self, other): - return _ufuncs.matmul(other, self) - - def __imatmul__(self, other): - return _ufuncs.matmul(self, other, out=self) - - # unary ops - __invert__ = _ufuncs.invert - __abs__ = _ufuncs.absolute - __pos__ = _ufuncs.positive - __neg__ = _ufuncs.negative - - conjugate = _ufuncs.conjugate - conj = conjugate - ### methods to match namespace functions - - squeeze = _funcs.squeeze - swapaxes = _funcs.swapaxes - def transpose(self, *axes): # np.transpose(arr, axis=None) but arr.transpose(*axes) return _funcs.transpose(self, axes) @@ -366,51 +349,18 @@ def reshape(self, *shape, order="C"): # arr.reshape(shape) and arr.reshape(*shape) return _funcs.reshape(self, shape, order=order) - ravel = _funcs.ravel - flatten = _funcs._flatten - def resize(self, *new_shape, refcheck=False): # ndarray.resize works in-place (may cause a reallocation though) self.tensor = _funcs_impl._ndarray_resize( self.tensor, new_shape, refcheck=refcheck ) - nonzero = _funcs.nonzero - clip = _funcs.clip - repeat = _funcs.repeat - - diagonal = _funcs.diagonal - trace = _funcs.trace - dot = _funcs.dot - ### sorting ### def sort(self, axis=-1, kind=None, order=None): # ndarray.sort works in-place _funcs.copyto(self, _funcs.sort(self, axis, kind, order)) - argsort = _funcs.argsort - searchsorted = _funcs.searchsorted - - ### reductions ### - argmax = _funcs.argmax - argmin = _funcs.argmin - - any = _funcs.any - all = _funcs.all - max = _funcs.max - min = _funcs.min - ptp = _funcs.ptp - - sum = _funcs.sum - prod = _funcs.prod - mean = _funcs.mean - var = _funcs.var - std = _funcs.std - - cumsum = _funcs.cumsum - cumprod = _funcs.cumprod - ### indexing ### def item(self, *args): # Mimic NumPy's implementation with three special cases (no arguments, @@ -444,8 +394,6 @@ def __setitem__(self, index, value): value = _helpers.ndarrays_to_tensors(value) return self.tensor.__setitem__(index, value) - take = _funcs.take - # This is the ideally the only place which talks to ndarray directly. # The rest goes through asarray (preferred) or array. @@ -484,9 +432,7 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N return ndarray(tensor) -def asarray(a, dtype=None, order=None, *, like=None): - if order is None: - order = "K" +def asarray(a, dtype=None, order="K", *, like=None): return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)