Skip to content

Simplify registartion of methods in ndarray #118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 21, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 109 additions & 168 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,85 @@ def __getitem__(self, key):
raise KeyError(f"No flag key '{key}'")


def create_method(fn, name=None):
name = name or fn.__name__

def f(*args, **kwargs):
return fn(*args, **kwargs)

f.__name__ = name
f.__qualname__ = f"ndarray.{name}"
return f


# Map ndarray.name_method -> np.name_func
# If name_func == None, it means that name_method == name_func
methods = {
"clip": None,
"flatten": "_flatten",
"nonzero": None,
"repeat": None,
"round": None,
"squeeze": None,
"swapaxes": None,
"ravel": None,
# linalg
"diagonal": None,
"dot": None,
"trace": None,
# sorting
"argsort": None,
"searchsorted": None,
# reductions
"argmax": None,
"argmin": None,
"any": None,
"all": None,
"max": None,
"min": None,
"ptp": None,
"sum": None,
"prod": None,
"mean": None,
"var": None,
"std": None,
# scans
"cumsum": None,
"cumprod": None,
# advanced indexing
"take": None,
}

dunder = {
"abs": "absolute",
"invert": None,
"pos": "positive",
"neg": "negative",
"gt": "greater",
"lt": "less",
"ge": "greater_equal",
"le": "less_equal",
}

# dunder methods with right-looking and in-place variants
ri_dunder = {
"add": None,
"sub": "subtract",
"mul": "multiply",
"truediv": "divide",
"floordiv": "floor_divide",
"divmod": None,
"pow": "float_power",
"mod": "remainder",
"and": "bitwise_and",
"or": "bitwise_or",
"xor": "bitwise_xor",
"lshift": "left_shift",
"rshift": "right_shift",
"matmul": None,
}


##################### ndarray class ###########################


Expand All @@ -72,6 +151,31 @@ def __init__(self, t=None):
"either array(...) or zeros/empty(...)"
)

# Register NumPy functions as methods
for method, name in methods.items():
fn = getattr(_funcs, name or method)
vars()[method] = create_method(fn, method)

# Regular methods but coming from ufuncs
conj = create_method(_ufuncs.conjugate, "conj")
conjugate = create_method(_ufuncs.conjugate)

for method, name in dunder.items():
fn = getattr(_ufuncs, name or method)
method = f"__{method}__"
vars()[method] = create_method(fn, method)

for method, name in ri_dunder.items():
fn = getattr(_ufuncs, name or method)
plain = f"__{method}__"
vars()[plain] = create_method(fn, plain)
rvar = f"__r{method}__"
vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
ivar = f"__i{method}__"
vars()[ivar] = create_method(
lambda self, other, fn=fn: fn(self, other, out=self), ivar
)

@property
def shape(self):
return tuple(self.tensor.shape)
Expand Down Expand Up @@ -100,18 +204,10 @@ def itemsize(self):
@property
def flags(self):
# Note contiguous in torch is assumed C-style

# check if F contiguous
from itertools import accumulate

f_strides = tuple(accumulate(list(self.tensor.shape), func=lambda x, y: x * y))
f_strides = (1,) + f_strides[:-1]
is_f_contiguous = f_strides == self.tensor.stride()

return Flags(
{
"C_CONTIGUOUS": self.tensor.is_contiguous(),
"F_CONTIGUOUS": is_f_contiguous,
"F_CONTIGUOUS": self.T.tensor.is_contiguous(),
"OWNDATA": self.tensor._base is None,
"WRITEABLE": True, # pytorch does not have readonly tensors
}
Expand Down Expand Up @@ -145,14 +241,11 @@ def imag(self):
def imag(self, value):
self.tensor.imag = asarray(value).tensor

round = _funcs.round

# ctors
def astype(self, dtype):
newt = ndarray()
torch_dtype = _dtypes.dtype(dtype).torch_dtype
newt.tensor = self.tensor.to(torch_dtype)
return newt
t = self.tensor.to(torch_dtype)
return ndarray(t)

def copy(self, order="C"):
if order != "C":
Expand Down Expand Up @@ -182,7 +275,7 @@ def __str__(self):
.replace("dtype=torch.", "dtype=")
)

__repr__ = __str__
__repr__ = create_method(__str__)

### comparisons ###
def __eq__(self, other):
Expand All @@ -201,11 +294,6 @@ def __ne__(self, other):
falsy = torch.full(self.shape, fill_value=True, dtype=bool)
return asarray(falsy)

__gt__ = _ufuncs.greater
__lt__ = _ufuncs.less
__ge__ = _ufuncs.greater_equal
__le__ = _ufuncs.less_equal

def __bool__(self):
try:
return bool(self.tensor)
Expand Down Expand Up @@ -247,117 +335,7 @@ def is_integer(self):
def __len__(self):
return self.tensor.shape[0]

### arithmetic ###

# add, self + other
__add__ = __radd__ = _ufuncs.add

def __iadd__(self, other):
return _ufuncs.add(self, other, out=self)

# sub, self - other
__sub__ = _ufuncs.subtract

# XXX: generate a function just for this? AND other non-commutative ops.
def __rsub__(self, other):
return _ufuncs.subtract(other, self)

def __isub__(self, other):
return _ufuncs.subtract(self, other, out=self)

# mul, self * other
__mul__ = __rmul__ = _ufuncs.multiply

def __imul__(self, other):
return _ufuncs.multiply(self, other, out=self)

# div, self / other
__truediv__ = _ufuncs.divide

def __rtruediv__(self, other):
return _ufuncs.divide(other, self)

def __itruediv__(self, other):
return _ufuncs.divide(self, other, out=self)

# floordiv, self // other
__floordiv__ = _ufuncs.floor_divide

def __rfloordiv__(self, other):
return _ufuncs.floor_divide(other, self)

def __ifloordiv__(self, other):
return _ufuncs.floor_divide(self, other, out=self)

__divmod__ = _ufuncs.divmod

# power, self**exponent
__pow__ = __rpow__ = _ufuncs.float_power

def __rpow__(self, exponent):
return _ufuncs.float_power(exponent, self)

def __ipow__(self, exponent):
return _ufuncs.float_power(self, exponent, out=self)

# remainder, self % other
__mod__ = __rmod__ = _ufuncs.remainder

def __imod__(self, other):
return _ufuncs.remainder(self, other, out=self)

# bitwise ops
# and, self & other
__and__ = __rand__ = _ufuncs.bitwise_and

def __iand__(self, other):
return _ufuncs.bitwise_and(self, other, out=self)

# or, self | other
__or__ = __ror__ = _ufuncs.bitwise_or

def __ior__(self, other):
return _ufuncs.bitwise_or(self, other, out=self)

# xor, self ^ other
__xor__ = __rxor__ = _ufuncs.bitwise_xor

def __ixor__(self, other):
return _ufuncs.bitwise_xor(self, other, out=self)

# bit shifts
__lshift__ = __rlshift__ = _ufuncs.left_shift

def __ilshift__(self, other):
return _ufuncs.left_shift(self, other, out=self)

__rshift__ = __rrshift__ = _ufuncs.right_shift

def __irshift__(self, other):
return _ufuncs.right_shift(self, other, out=self)

__matmul__ = _ufuncs.matmul

def __rmatmul__(self, other):
return _ufuncs.matmul(other, self)

def __imatmul__(self, other):
return _ufuncs.matmul(self, other, out=self)

# unary ops
__invert__ = _ufuncs.invert
__abs__ = _ufuncs.absolute
__pos__ = _ufuncs.positive
__neg__ = _ufuncs.negative

conjugate = _ufuncs.conjugate
conj = conjugate

### methods to match namespace functions

squeeze = _funcs.squeeze
swapaxes = _funcs.swapaxes

def transpose(self, *axes):
# np.transpose(arr, axis=None) but arr.transpose(*axes)
return _funcs.transpose(self, axes)
Expand All @@ -366,51 +344,18 @@ def reshape(self, *shape, order="C"):
# arr.reshape(shape) and arr.reshape(*shape)
return _funcs.reshape(self, shape, order=order)

ravel = _funcs.ravel
flatten = _funcs._flatten

def resize(self, *new_shape, refcheck=False):
# ndarray.resize works in-place (may cause a reallocation though)
self.tensor = _funcs_impl._ndarray_resize(
self.tensor, new_shape, refcheck=refcheck
)

nonzero = _funcs.nonzero
clip = _funcs.clip
repeat = _funcs.repeat

diagonal = _funcs.diagonal
trace = _funcs.trace
dot = _funcs.dot

### sorting ###

def sort(self, axis=-1, kind=None, order=None):
# ndarray.sort works in-place
_funcs.copyto(self, _funcs.sort(self, axis, kind, order))

argsort = _funcs.argsort
searchsorted = _funcs.searchsorted

### reductions ###
argmax = _funcs.argmax
argmin = _funcs.argmin

any = _funcs.any
all = _funcs.all
max = _funcs.max
min = _funcs.min
ptp = _funcs.ptp

sum = _funcs.sum
prod = _funcs.prod
mean = _funcs.mean
var = _funcs.var
std = _funcs.std

cumsum = _funcs.cumsum
cumprod = _funcs.cumprod

### indexing ###
def item(self, *args):
# Mimic NumPy's implementation with three special cases (no arguments,
Expand Down Expand Up @@ -444,8 +389,6 @@ def __setitem__(self, index, value):
value = _helpers.ndarrays_to_tensors(value)
return self.tensor.__setitem__(index, value)

take = _funcs.take


# This is the ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.
Expand Down Expand Up @@ -484,9 +427,7 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
return ndarray(tensor)


def asarray(a, dtype=None, order=None, *, like=None):
if order is None:
order = "K"
def asarray(a, dtype=None, order="K", *, like=None):
return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)


Expand Down