@@ -57,6 +57,84 @@ def __getitem__(self, key):
57
57
raise KeyError (f"No flag key '{ key } '" )
58
58
59
59
60
+ def create_method (fn , name = None ):
61
+ name = name or fn .__name__
62
+
63
+ def f (* args , ** kwargs ):
64
+ return fn (* args , ** kwargs )
65
+
66
+ f .__name__ = name
67
+ f .__qualname__ = f"ndarray.{ name } "
68
+ return f
69
+
70
+
71
+ # Map ndarray.name_method -> np.name_func
72
+ # If name_func == None, it means that name_method == name_func
73
+ methods = {
74
+ "clip" : None ,
75
+ "flatten" : "_flatten" ,
76
+ "nonzero" : None ,
77
+ "repeat" : None ,
78
+ "round" : None ,
79
+ "squeeze" : None ,
80
+ "swapaxes" : None ,
81
+ "ravel" : None ,
82
+ # linalg
83
+ "diagonal" : None ,
84
+ "dot" : None ,
85
+ "trace" : None ,
86
+ # sorting
87
+ "argsort" : None ,
88
+ "searchsorted" : None ,
89
+ # reductions
90
+ "argmax" : None ,
91
+ "argmin" : None ,
92
+ "any" : None ,
93
+ "all" : None ,
94
+ "max" : None ,
95
+ "min" : None ,
96
+ "ptp" : None ,
97
+ "sum" : None ,
98
+ "prod" : None ,
99
+ "mean" : None ,
100
+ "var" : None ,
101
+ "std" : None ,
102
+ # scans
103
+ "cumsum" : None ,
104
+ "cumprod" : None ,
105
+ # advanced indexing
106
+ "take" : None ,
107
+ }
108
+
109
+ dunder = {
110
+ "abs" : "absolute" ,
111
+ "invert" : None ,
112
+ "pos" : "positive" ,
113
+ "neg" : "negative" ,
114
+ "gt" : "greater" ,
115
+ "lt" : "less" ,
116
+ "ge" : "greater_equal" ,
117
+ "le" : "less_equal" ,
118
+ }
119
+
120
+ # dunder methods with right-looking and in-place variants
121
+ ri_dunder = {
122
+ "add" : None ,
123
+ "sub" : "subtract" ,
124
+ "mul" : "multiply" ,
125
+ "truediv" : "divide" ,
126
+ "floordiv" : "floor_divide" ,
127
+ "pow" : "float_power" ,
128
+ "mod" : "remainder" ,
129
+ "and" : "bitwise_and" ,
130
+ "or" : "bitwise_or" ,
131
+ "xor" : "bitwise_xor" ,
132
+ "lshift" : "left_shift" ,
133
+ "rshift" : "right_shift" ,
134
+ "matmul" : None ,
135
+ }
136
+
137
+
60
138
##################### ndarray class ###########################
61
139
62
140
@@ -72,6 +150,37 @@ def __init__(self, t=None):
72
150
"either array(...) or zeros/empty(...)"
73
151
)
74
152
153
+ # Register NumPy functions as methods
154
+ for method , name in methods .items ():
155
+ fn = getattr (_funcs , name or method )
156
+ vars ()[method ] = create_method (fn , method )
157
+
158
+ # Regular methods but coming from ufuncs
159
+ conj = create_method (_ufuncs .conjugate , "conj" )
160
+ conjugate = create_method (_ufuncs .conjugate )
161
+
162
+ for method , name in dunder .items ():
163
+ fn = getattr (_ufuncs , name or method )
164
+ method = f"__{ method } __"
165
+ vars ()[method ] = create_method (fn , method )
166
+
167
+ for method , name in ri_dunder .items ():
168
+ fn = getattr (_ufuncs , name or method )
169
+ plain = f"__{ method } __"
170
+ vars ()[plain ] = create_method (fn , plain )
171
+ rvar = f"__r{ method } __"
172
+ vars ()[rvar ] = create_method (lambda self , other , fn = fn : fn (other , self ), rvar )
173
+ ivar = f"__i{ method } __"
174
+ vars ()[ivar ] = create_method (
175
+ lambda self , other , fn = fn : fn (self , other , out = self ), ivar
176
+ )
177
+
178
+ # There's no __idivmod__
179
+ __divmod__ = create_method (_ufuncs .divmod , "__divmod__" )
180
+ __rdivmod__ = create_method (
181
+ lambda self , other : _ufuncs .divmod (other , self ), "__rdivmod__"
182
+ )
183
+
75
184
@property
76
185
def shape (self ):
77
186
return tuple (self .tensor .shape )
@@ -100,18 +209,10 @@ def itemsize(self):
100
209
@property
101
210
def flags (self ):
102
211
# Note contiguous in torch is assumed C-style
103
-
104
- # check if F contiguous
105
- from itertools import accumulate
106
-
107
- f_strides = tuple (accumulate (list (self .tensor .shape ), func = lambda x , y : x * y ))
108
- f_strides = (1 ,) + f_strides [:- 1 ]
109
- is_f_contiguous = f_strides == self .tensor .stride ()
110
-
111
212
return Flags (
112
213
{
113
214
"C_CONTIGUOUS" : self .tensor .is_contiguous (),
114
- "F_CONTIGUOUS" : is_f_contiguous ,
215
+ "F_CONTIGUOUS" : self . T . tensor . is_contiguous () ,
115
216
"OWNDATA" : self .tensor ._base is None ,
116
217
"WRITEABLE" : True , # pytorch does not have readonly tensors
117
218
}
@@ -145,14 +246,11 @@ def imag(self):
145
246
def imag (self , value ):
146
247
self .tensor .imag = asarray (value ).tensor
147
248
148
- round = _funcs .round
149
-
150
249
# ctors
151
250
def astype (self , dtype ):
152
- newt = ndarray ()
153
251
torch_dtype = _dtypes .dtype (dtype ).torch_dtype
154
- newt . tensor = self .tensor .to (torch_dtype )
155
- return newt
252
+ t = self .tensor .to (torch_dtype )
253
+ return ndarray ( t )
156
254
157
255
def copy (self , order = "C" ):
158
256
if order != "C" :
@@ -182,7 +280,7 @@ def __str__(self):
182
280
.replace ("dtype=torch." , "dtype=" )
183
281
)
184
282
185
- __repr__ = __str__
283
+ __repr__ = create_method ( __str__ )
186
284
187
285
### comparisons ###
188
286
def __eq__ (self , other ):
@@ -201,11 +299,6 @@ def __ne__(self, other):
201
299
falsy = torch .full (self .shape , fill_value = True , dtype = bool )
202
300
return asarray (falsy )
203
301
204
- __gt__ = _ufuncs .greater
205
- __lt__ = _ufuncs .less
206
- __ge__ = _ufuncs .greater_equal
207
- __le__ = _ufuncs .less_equal
208
-
209
302
def __bool__ (self ):
210
303
try :
211
304
return bool (self .tensor )
@@ -247,117 +340,7 @@ def is_integer(self):
247
340
def __len__ (self ):
248
341
return self .tensor .shape [0 ]
249
342
250
- ### arithmetic ###
251
-
252
- # add, self + other
253
- __add__ = __radd__ = _ufuncs .add
254
-
255
- def __iadd__ (self , other ):
256
- return _ufuncs .add (self , other , out = self )
257
-
258
- # sub, self - other
259
- __sub__ = _ufuncs .subtract
260
-
261
- # XXX: generate a function just for this? AND other non-commutative ops.
262
- def __rsub__ (self , other ):
263
- return _ufuncs .subtract (other , self )
264
-
265
- def __isub__ (self , other ):
266
- return _ufuncs .subtract (self , other , out = self )
267
-
268
- # mul, self * other
269
- __mul__ = __rmul__ = _ufuncs .multiply
270
-
271
- def __imul__ (self , other ):
272
- return _ufuncs .multiply (self , other , out = self )
273
-
274
- # div, self / other
275
- __truediv__ = _ufuncs .divide
276
-
277
- def __rtruediv__ (self , other ):
278
- return _ufuncs .divide (other , self )
279
-
280
- def __itruediv__ (self , other ):
281
- return _ufuncs .divide (self , other , out = self )
282
-
283
- # floordiv, self // other
284
- __floordiv__ = _ufuncs .floor_divide
285
-
286
- def __rfloordiv__ (self , other ):
287
- return _ufuncs .floor_divide (other , self )
288
-
289
- def __ifloordiv__ (self , other ):
290
- return _ufuncs .floor_divide (self , other , out = self )
291
-
292
- __divmod__ = _ufuncs .divmod
293
-
294
- # power, self**exponent
295
- __pow__ = __rpow__ = _ufuncs .float_power
296
-
297
- def __rpow__ (self , exponent ):
298
- return _ufuncs .float_power (exponent , self )
299
-
300
- def __ipow__ (self , exponent ):
301
- return _ufuncs .float_power (self , exponent , out = self )
302
-
303
- # remainder, self % other
304
- __mod__ = __rmod__ = _ufuncs .remainder
305
-
306
- def __imod__ (self , other ):
307
- return _ufuncs .remainder (self , other , out = self )
308
-
309
- # bitwise ops
310
- # and, self & other
311
- __and__ = __rand__ = _ufuncs .bitwise_and
312
-
313
- def __iand__ (self , other ):
314
- return _ufuncs .bitwise_and (self , other , out = self )
315
-
316
- # or, self | other
317
- __or__ = __ror__ = _ufuncs .bitwise_or
318
-
319
- def __ior__ (self , other ):
320
- return _ufuncs .bitwise_or (self , other , out = self )
321
-
322
- # xor, self ^ other
323
- __xor__ = __rxor__ = _ufuncs .bitwise_xor
324
-
325
- def __ixor__ (self , other ):
326
- return _ufuncs .bitwise_xor (self , other , out = self )
327
-
328
- # bit shifts
329
- __lshift__ = __rlshift__ = _ufuncs .left_shift
330
-
331
- def __ilshift__ (self , other ):
332
- return _ufuncs .left_shift (self , other , out = self )
333
-
334
- __rshift__ = __rrshift__ = _ufuncs .right_shift
335
-
336
- def __irshift__ (self , other ):
337
- return _ufuncs .right_shift (self , other , out = self )
338
-
339
- __matmul__ = _ufuncs .matmul
340
-
341
- def __rmatmul__ (self , other ):
342
- return _ufuncs .matmul (other , self )
343
-
344
- def __imatmul__ (self , other ):
345
- return _ufuncs .matmul (self , other , out = self )
346
-
347
- # unary ops
348
- __invert__ = _ufuncs .invert
349
- __abs__ = _ufuncs .absolute
350
- __pos__ = _ufuncs .positive
351
- __neg__ = _ufuncs .negative
352
-
353
- conjugate = _ufuncs .conjugate
354
- conj = conjugate
355
-
356
343
### methods to match namespace functions
357
-
358
- squeeze = _funcs .squeeze
359
- swapaxes = _funcs .swapaxes
360
-
361
344
def transpose (self , * axes ):
362
345
# np.transpose(arr, axis=None) but arr.transpose(*axes)
363
346
return _funcs .transpose (self , axes )
@@ -366,51 +349,18 @@ def reshape(self, *shape, order="C"):
366
349
# arr.reshape(shape) and arr.reshape(*shape)
367
350
return _funcs .reshape (self , shape , order = order )
368
351
369
- ravel = _funcs .ravel
370
- flatten = _funcs ._flatten
371
-
372
352
def resize (self , * new_shape , refcheck = False ):
373
353
# ndarray.resize works in-place (may cause a reallocation though)
374
354
self .tensor = _funcs_impl ._ndarray_resize (
375
355
self .tensor , new_shape , refcheck = refcheck
376
356
)
377
357
378
- nonzero = _funcs .nonzero
379
- clip = _funcs .clip
380
- repeat = _funcs .repeat
381
-
382
- diagonal = _funcs .diagonal
383
- trace = _funcs .trace
384
- dot = _funcs .dot
385
-
386
358
### sorting ###
387
359
388
360
def sort (self , axis = - 1 , kind = None , order = None ):
389
361
# ndarray.sort works in-place
390
362
_funcs .copyto (self , _funcs .sort (self , axis , kind , order ))
391
363
392
- argsort = _funcs .argsort
393
- searchsorted = _funcs .searchsorted
394
-
395
- ### reductions ###
396
- argmax = _funcs .argmax
397
- argmin = _funcs .argmin
398
-
399
- any = _funcs .any
400
- all = _funcs .all
401
- max = _funcs .max
402
- min = _funcs .min
403
- ptp = _funcs .ptp
404
-
405
- sum = _funcs .sum
406
- prod = _funcs .prod
407
- mean = _funcs .mean
408
- var = _funcs .var
409
- std = _funcs .std
410
-
411
- cumsum = _funcs .cumsum
412
- cumprod = _funcs .cumprod
413
-
414
364
### indexing ###
415
365
def item (self , * args ):
416
366
# Mimic NumPy's implementation with three special cases (no arguments,
@@ -447,8 +397,6 @@ def __setitem__(self, index, value):
447
397
value = _util .cast_if_needed (value , self .tensor .dtype )
448
398
return self .tensor .__setitem__ (index , value )
449
399
450
- take = _funcs .take
451
-
452
400
453
401
# This is the ideally the only place which talks to ndarray directly.
454
402
# The rest goes through asarray (preferred) or array.
@@ -487,9 +435,7 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
487
435
return ndarray (tensor )
488
436
489
437
490
- def asarray (a , dtype = None , order = None , * , like = None ):
491
- if order is None :
492
- order = "K"
438
+ def asarray (a , dtype = None , order = "K" , * , like = None ):
493
439
return array (a , dtype = dtype , order = order , like = like , copy = False , ndmin = 0 )
494
440
495
441
0 commit comments