Skip to content

Commit 5be250a

Browse files
authored
Merge pull request #118 from Quansight-Labs/dunder
Simplify registartion of methods in ndarray
2 parents 134db55 + fcf6b48 commit 5be250a

File tree

1 file changed

+114
-168
lines changed

1 file changed

+114
-168
lines changed

torch_np/_ndarray.py

Lines changed: 114 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,84 @@ def __getitem__(self, key):
5757
raise KeyError(f"No flag key '{key}'")
5858

5959

60+
def create_method(fn, name=None):
61+
name = name or fn.__name__
62+
63+
def f(*args, **kwargs):
64+
return fn(*args, **kwargs)
65+
66+
f.__name__ = name
67+
f.__qualname__ = f"ndarray.{name}"
68+
return f
69+
70+
71+
# Map ndarray.name_method -> np.name_func
72+
# If name_func == None, it means that name_method == name_func
73+
methods = {
74+
"clip": None,
75+
"flatten": "_flatten",
76+
"nonzero": None,
77+
"repeat": None,
78+
"round": None,
79+
"squeeze": None,
80+
"swapaxes": None,
81+
"ravel": None,
82+
# linalg
83+
"diagonal": None,
84+
"dot": None,
85+
"trace": None,
86+
# sorting
87+
"argsort": None,
88+
"searchsorted": None,
89+
# reductions
90+
"argmax": None,
91+
"argmin": None,
92+
"any": None,
93+
"all": None,
94+
"max": None,
95+
"min": None,
96+
"ptp": None,
97+
"sum": None,
98+
"prod": None,
99+
"mean": None,
100+
"var": None,
101+
"std": None,
102+
# scans
103+
"cumsum": None,
104+
"cumprod": None,
105+
# advanced indexing
106+
"take": None,
107+
}
108+
109+
dunder = {
110+
"abs": "absolute",
111+
"invert": None,
112+
"pos": "positive",
113+
"neg": "negative",
114+
"gt": "greater",
115+
"lt": "less",
116+
"ge": "greater_equal",
117+
"le": "less_equal",
118+
}
119+
120+
# dunder methods with right-looking and in-place variants
121+
ri_dunder = {
122+
"add": None,
123+
"sub": "subtract",
124+
"mul": "multiply",
125+
"truediv": "divide",
126+
"floordiv": "floor_divide",
127+
"pow": "float_power",
128+
"mod": "remainder",
129+
"and": "bitwise_and",
130+
"or": "bitwise_or",
131+
"xor": "bitwise_xor",
132+
"lshift": "left_shift",
133+
"rshift": "right_shift",
134+
"matmul": None,
135+
}
136+
137+
60138
##################### ndarray class ###########################
61139

62140

@@ -72,6 +150,37 @@ def __init__(self, t=None):
72150
"either array(...) or zeros/empty(...)"
73151
)
74152

153+
# Register NumPy functions as methods
154+
for method, name in methods.items():
155+
fn = getattr(_funcs, name or method)
156+
vars()[method] = create_method(fn, method)
157+
158+
# Regular methods but coming from ufuncs
159+
conj = create_method(_ufuncs.conjugate, "conj")
160+
conjugate = create_method(_ufuncs.conjugate)
161+
162+
for method, name in dunder.items():
163+
fn = getattr(_ufuncs, name or method)
164+
method = f"__{method}__"
165+
vars()[method] = create_method(fn, method)
166+
167+
for method, name in ri_dunder.items():
168+
fn = getattr(_ufuncs, name or method)
169+
plain = f"__{method}__"
170+
vars()[plain] = create_method(fn, plain)
171+
rvar = f"__r{method}__"
172+
vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
173+
ivar = f"__i{method}__"
174+
vars()[ivar] = create_method(
175+
lambda self, other, fn=fn: fn(self, other, out=self), ivar
176+
)
177+
178+
# There's no __idivmod__
179+
__divmod__ = create_method(_ufuncs.divmod, "__divmod__")
180+
__rdivmod__ = create_method(
181+
lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
182+
)
183+
75184
@property
76185
def shape(self):
77186
return tuple(self.tensor.shape)
@@ -100,18 +209,10 @@ def itemsize(self):
100209
@property
101210
def flags(self):
102211
# Note contiguous in torch is assumed C-style
103-
104-
# check if F contiguous
105-
from itertools import accumulate
106-
107-
f_strides = tuple(accumulate(list(self.tensor.shape), func=lambda x, y: x * y))
108-
f_strides = (1,) + f_strides[:-1]
109-
is_f_contiguous = f_strides == self.tensor.stride()
110-
111212
return Flags(
112213
{
113214
"C_CONTIGUOUS": self.tensor.is_contiguous(),
114-
"F_CONTIGUOUS": is_f_contiguous,
215+
"F_CONTIGUOUS": self.T.tensor.is_contiguous(),
115216
"OWNDATA": self.tensor._base is None,
116217
"WRITEABLE": True, # pytorch does not have readonly tensors
117218
}
@@ -145,14 +246,11 @@ def imag(self):
145246
def imag(self, value):
146247
self.tensor.imag = asarray(value).tensor
147248

148-
round = _funcs.round
149-
150249
# ctors
151250
def astype(self, dtype):
152-
newt = ndarray()
153251
torch_dtype = _dtypes.dtype(dtype).torch_dtype
154-
newt.tensor = self.tensor.to(torch_dtype)
155-
return newt
252+
t = self.tensor.to(torch_dtype)
253+
return ndarray(t)
156254

157255
def copy(self, order="C"):
158256
if order != "C":
@@ -182,7 +280,7 @@ def __str__(self):
182280
.replace("dtype=torch.", "dtype=")
183281
)
184282

185-
__repr__ = __str__
283+
__repr__ = create_method(__str__)
186284

187285
### comparisons ###
188286
def __eq__(self, other):
@@ -201,11 +299,6 @@ def __ne__(self, other):
201299
falsy = torch.full(self.shape, fill_value=True, dtype=bool)
202300
return asarray(falsy)
203301

204-
__gt__ = _ufuncs.greater
205-
__lt__ = _ufuncs.less
206-
__ge__ = _ufuncs.greater_equal
207-
__le__ = _ufuncs.less_equal
208-
209302
def __bool__(self):
210303
try:
211304
return bool(self.tensor)
@@ -247,117 +340,7 @@ def is_integer(self):
247340
def __len__(self):
248341
return self.tensor.shape[0]
249342

250-
### arithmetic ###
251-
252-
# add, self + other
253-
__add__ = __radd__ = _ufuncs.add
254-
255-
def __iadd__(self, other):
256-
return _ufuncs.add(self, other, out=self)
257-
258-
# sub, self - other
259-
__sub__ = _ufuncs.subtract
260-
261-
# XXX: generate a function just for this? AND other non-commutative ops.
262-
def __rsub__(self, other):
263-
return _ufuncs.subtract(other, self)
264-
265-
def __isub__(self, other):
266-
return _ufuncs.subtract(self, other, out=self)
267-
268-
# mul, self * other
269-
__mul__ = __rmul__ = _ufuncs.multiply
270-
271-
def __imul__(self, other):
272-
return _ufuncs.multiply(self, other, out=self)
273-
274-
# div, self / other
275-
__truediv__ = _ufuncs.divide
276-
277-
def __rtruediv__(self, other):
278-
return _ufuncs.divide(other, self)
279-
280-
def __itruediv__(self, other):
281-
return _ufuncs.divide(self, other, out=self)
282-
283-
# floordiv, self // other
284-
__floordiv__ = _ufuncs.floor_divide
285-
286-
def __rfloordiv__(self, other):
287-
return _ufuncs.floor_divide(other, self)
288-
289-
def __ifloordiv__(self, other):
290-
return _ufuncs.floor_divide(self, other, out=self)
291-
292-
__divmod__ = _ufuncs.divmod
293-
294-
# power, self**exponent
295-
__pow__ = __rpow__ = _ufuncs.float_power
296-
297-
def __rpow__(self, exponent):
298-
return _ufuncs.float_power(exponent, self)
299-
300-
def __ipow__(self, exponent):
301-
return _ufuncs.float_power(self, exponent, out=self)
302-
303-
# remainder, self % other
304-
__mod__ = __rmod__ = _ufuncs.remainder
305-
306-
def __imod__(self, other):
307-
return _ufuncs.remainder(self, other, out=self)
308-
309-
# bitwise ops
310-
# and, self & other
311-
__and__ = __rand__ = _ufuncs.bitwise_and
312-
313-
def __iand__(self, other):
314-
return _ufuncs.bitwise_and(self, other, out=self)
315-
316-
# or, self | other
317-
__or__ = __ror__ = _ufuncs.bitwise_or
318-
319-
def __ior__(self, other):
320-
return _ufuncs.bitwise_or(self, other, out=self)
321-
322-
# xor, self ^ other
323-
__xor__ = __rxor__ = _ufuncs.bitwise_xor
324-
325-
def __ixor__(self, other):
326-
return _ufuncs.bitwise_xor(self, other, out=self)
327-
328-
# bit shifts
329-
__lshift__ = __rlshift__ = _ufuncs.left_shift
330-
331-
def __ilshift__(self, other):
332-
return _ufuncs.left_shift(self, other, out=self)
333-
334-
__rshift__ = __rrshift__ = _ufuncs.right_shift
335-
336-
def __irshift__(self, other):
337-
return _ufuncs.right_shift(self, other, out=self)
338-
339-
__matmul__ = _ufuncs.matmul
340-
341-
def __rmatmul__(self, other):
342-
return _ufuncs.matmul(other, self)
343-
344-
def __imatmul__(self, other):
345-
return _ufuncs.matmul(self, other, out=self)
346-
347-
# unary ops
348-
__invert__ = _ufuncs.invert
349-
__abs__ = _ufuncs.absolute
350-
__pos__ = _ufuncs.positive
351-
__neg__ = _ufuncs.negative
352-
353-
conjugate = _ufuncs.conjugate
354-
conj = conjugate
355-
356343
### methods to match namespace functions
357-
358-
squeeze = _funcs.squeeze
359-
swapaxes = _funcs.swapaxes
360-
361344
def transpose(self, *axes):
362345
# np.transpose(arr, axis=None) but arr.transpose(*axes)
363346
return _funcs.transpose(self, axes)
@@ -366,51 +349,18 @@ def reshape(self, *shape, order="C"):
366349
# arr.reshape(shape) and arr.reshape(*shape)
367350
return _funcs.reshape(self, shape, order=order)
368351

369-
ravel = _funcs.ravel
370-
flatten = _funcs._flatten
371-
372352
def resize(self, *new_shape, refcheck=False):
373353
# ndarray.resize works in-place (may cause a reallocation though)
374354
self.tensor = _funcs_impl._ndarray_resize(
375355
self.tensor, new_shape, refcheck=refcheck
376356
)
377357

378-
nonzero = _funcs.nonzero
379-
clip = _funcs.clip
380-
repeat = _funcs.repeat
381-
382-
diagonal = _funcs.diagonal
383-
trace = _funcs.trace
384-
dot = _funcs.dot
385-
386358
### sorting ###
387359

388360
def sort(self, axis=-1, kind=None, order=None):
389361
# ndarray.sort works in-place
390362
_funcs.copyto(self, _funcs.sort(self, axis, kind, order))
391363

392-
argsort = _funcs.argsort
393-
searchsorted = _funcs.searchsorted
394-
395-
### reductions ###
396-
argmax = _funcs.argmax
397-
argmin = _funcs.argmin
398-
399-
any = _funcs.any
400-
all = _funcs.all
401-
max = _funcs.max
402-
min = _funcs.min
403-
ptp = _funcs.ptp
404-
405-
sum = _funcs.sum
406-
prod = _funcs.prod
407-
mean = _funcs.mean
408-
var = _funcs.var
409-
std = _funcs.std
410-
411-
cumsum = _funcs.cumsum
412-
cumprod = _funcs.cumprod
413-
414364
### indexing ###
415365
def item(self, *args):
416366
# Mimic NumPy's implementation with three special cases (no arguments,
@@ -447,8 +397,6 @@ def __setitem__(self, index, value):
447397
value = _util.cast_if_needed(value, self.tensor.dtype)
448398
return self.tensor.__setitem__(index, value)
449399

450-
take = _funcs.take
451-
452400

453401
# This is the ideally the only place which talks to ndarray directly.
454402
# The rest goes through asarray (preferred) or array.
@@ -487,9 +435,7 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
487435
return ndarray(tensor)
488436

489437

490-
def asarray(a, dtype=None, order=None, *, like=None):
491-
if order is None:
492-
order = "K"
438+
def asarray(a, dtype=None, order="K", *, like=None):
493439
return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)
494440

495441

0 commit comments

Comments
 (0)