23
23
AxisLike ,
24
24
DTypeLike ,
25
25
NDArray ,
26
+ NotImplementedType ,
26
27
OutArray ,
27
28
normalize_array_like ,
28
29
)
33
34
###### array creation routines
34
35
35
36
36
- def copy (a : ArrayLike , order : NotImplemented = "K" , subok : NotImplemented = False ):
37
+ def copy (
38
+ a : ArrayLike , order : NotImplementedType = "K" , subok : NotImplementedType = False
39
+ ):
37
40
return a .clone ()
38
41
39
42
40
43
def copyto (
41
- dst : NDArray , src : ArrayLike , casting = "same_kind" , where : NotImplemented = NoValue
44
+ dst : NDArray ,
45
+ src : ArrayLike ,
46
+ casting = "same_kind" ,
47
+ where : NotImplementedType = NoValue ,
42
48
):
43
49
(src ,) = _util .typecast_tensors ((src ,), dst .dtype , casting = casting )
44
50
dst .copy_ (src )
@@ -320,7 +326,7 @@ def arange(
320
326
step : Optional [ArrayLike ] = 1 ,
321
327
dtype : DTypeLike = None ,
322
328
* ,
323
- like : NotImplemented = None ,
329
+ like : NotImplementedType = None ,
324
330
):
325
331
if step == 0 :
326
332
raise ZeroDivisionError
@@ -365,9 +371,9 @@ def arange(
365
371
def empty (
366
372
shape ,
367
373
dtype : DTypeLike = float ,
368
- order : NotImplemented = "C" ,
374
+ order : NotImplementedType = "C" ,
369
375
* ,
370
- like : NotImplemented = None ,
376
+ like : NotImplementedType = None ,
371
377
):
372
378
if dtype is None :
373
379
dtype = _dtypes_impl .default_float_dtype
@@ -381,8 +387,8 @@ def empty(
381
387
def empty_like (
382
388
prototype : ArrayLike ,
383
389
dtype : DTypeLike = None ,
384
- order : NotImplemented = "K" ,
385
- subok : NotImplemented = False ,
390
+ order : NotImplementedType = "K" ,
391
+ subok : NotImplementedType = False ,
386
392
shape = None ,
387
393
):
388
394
result = torch .empty_like (prototype , dtype = dtype )
@@ -395,9 +401,9 @@ def full(
395
401
shape ,
396
402
fill_value : ArrayLike ,
397
403
dtype : DTypeLike = None ,
398
- order : NotImplemented = "C" ,
404
+ order : NotImplementedType = "C" ,
399
405
* ,
400
- like : NotImplemented = None ,
406
+ like : NotImplementedType = None ,
401
407
):
402
408
if isinstance (shape , int ):
403
409
shape = (shape ,)
@@ -412,8 +418,8 @@ def full_like(
412
418
a : ArrayLike ,
413
419
fill_value ,
414
420
dtype : DTypeLike = None ,
415
- order : NotImplemented = "K" ,
416
- subok : NotImplemented = False ,
421
+ order : NotImplementedType = "K" ,
422
+ subok : NotImplementedType = False ,
417
423
shape = None ,
418
424
):
419
425
# XXX: fill_value broadcasts
@@ -426,9 +432,9 @@ def full_like(
426
432
def ones (
427
433
shape ,
428
434
dtype : DTypeLike = None ,
429
- order : NotImplemented = "C" ,
435
+ order : NotImplementedType = "C" ,
430
436
* ,
431
- like : NotImplemented = None ,
437
+ like : NotImplementedType = None ,
432
438
):
433
439
if dtype is None :
434
440
dtype = _dtypes_impl .default_float_dtype
@@ -438,8 +444,8 @@ def ones(
438
444
def ones_like (
439
445
a : ArrayLike ,
440
446
dtype : DTypeLike = None ,
441
- order : NotImplemented = "K" ,
442
- subok : NotImplemented = False ,
447
+ order : NotImplementedType = "K" ,
448
+ subok : NotImplementedType = False ,
443
449
shape = None ,
444
450
):
445
451
result = torch .ones_like (a , dtype = dtype )
@@ -451,9 +457,9 @@ def ones_like(
451
457
def zeros (
452
458
shape ,
453
459
dtype : DTypeLike = None ,
454
- order : NotImplemented = "C" ,
460
+ order : NotImplementedType = "C" ,
455
461
* ,
456
- like : NotImplemented = None ,
462
+ like : NotImplementedType = None ,
457
463
):
458
464
if dtype is None :
459
465
dtype = _dtypes_impl .default_float_dtype
@@ -463,8 +469,8 @@ def zeros(
463
469
def zeros_like (
464
470
a : ArrayLike ,
465
471
dtype : DTypeLike = None ,
466
- order : NotImplemented = "K" ,
467
- subok : NotImplemented = False ,
472
+ order : NotImplementedType = "K" ,
473
+ subok : NotImplementedType = False ,
468
474
shape = None ,
469
475
):
470
476
result = torch .zeros_like (a , dtype = dtype )
@@ -647,14 +653,14 @@ def rot90(m: ArrayLike, k=1, axes=(0, 1)):
647
653
# ### broadcasting and indices ###
648
654
649
655
650
- def broadcast_to (array : ArrayLike , shape , subok : NotImplemented = False ):
656
+ def broadcast_to (array : ArrayLike , shape , subok : NotImplementedType = False ):
651
657
return torch .broadcast_to (array , size = shape )
652
658
653
659
654
660
from torch import broadcast_shapes
655
661
656
662
657
- def broadcast_arrays (* args : ArrayLike , subok : NotImplemented = False ):
663
+ def broadcast_arrays (* args : ArrayLike , subok : NotImplementedType = False ):
658
664
return torch .broadcast_tensors (* args )
659
665
660
666
@@ -740,7 +746,7 @@ def triu_indices_from(arr: ArrayLike, k=0):
740
746
return tuple (result )
741
747
742
748
743
- def tri (N , M = None , k = 0 , dtype : DTypeLike = float , * , like : NotImplemented = None ):
749
+ def tri (N , M = None , k = 0 , dtype : DTypeLike = float , * , like : NotImplementedType = None ):
744
750
if M is None :
745
751
M = N
746
752
tensor = torch .ones ((N , M ), dtype = dtype )
@@ -758,7 +764,7 @@ def nanmean(
758
764
out : Optional [OutArray ] = None ,
759
765
keepdims = NoValue ,
760
766
* ,
761
- where : NotImplemented = NoValue ,
767
+ where : NotImplementedType = NoValue ,
762
768
):
763
769
# XXX: this needs to be rewritten
764
770
if dtype is None :
@@ -892,7 +898,7 @@ def take(
892
898
indices : ArrayLike ,
893
899
axis = None ,
894
900
out : Optional [OutArray ] = None ,
895
- mode : NotImplemented = "raise" ,
901
+ mode : NotImplementedType = "raise" ,
896
902
):
897
903
(a ,), axis = _util .axis_none_ravel (a , axis = axis )
898
904
axis = _util .normalize_axis_index (axis , a .ndim )
@@ -923,12 +929,12 @@ def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
923
929
924
930
def unique (
925
931
ar : ArrayLike ,
926
- return_index : NotImplemented = False ,
932
+ return_index : NotImplementedType = False ,
927
933
return_inverse = False ,
928
934
return_counts = False ,
929
935
axis = None ,
930
936
* ,
931
- equal_nan : NotImplemented = True ,
937
+ equal_nan : NotImplementedType = True ,
932
938
):
933
939
if axis is None :
934
940
ar = ar .ravel ()
@@ -1074,9 +1080,9 @@ def eye(
1074
1080
M = None ,
1075
1081
k = 0 ,
1076
1082
dtype : DTypeLike = float ,
1077
- order : NotImplemented = "C" ,
1083
+ order : NotImplementedType = "C" ,
1078
1084
* ,
1079
- like : NotImplemented = None ,
1085
+ like : NotImplementedType = None ,
1080
1086
):
1081
1087
if M is None :
1082
1088
M = N
@@ -1085,7 +1091,7 @@ def eye(
1085
1091
return z
1086
1092
1087
1093
1088
- def identity (n , dtype : DTypeLike = None , * , like : NotImplemented = None ):
1094
+ def identity (n , dtype : DTypeLike = None , * , like : NotImplementedType = None ):
1089
1095
return torch .eye (n , dtype = dtype )
1090
1096
1091
1097
@@ -1230,14 +1236,14 @@ def _sort_helper(tensor, axis, kind, order):
1230
1236
return tensor , axis , stable
1231
1237
1232
1238
1233
- def sort (a : ArrayLike , axis = - 1 , kind = None , order : NotImplemented = None ):
1239
+ def sort (a : ArrayLike , axis = - 1 , kind = None , order : NotImplementedType = None ):
1234
1240
# `order` keyword arg is only relevant for structured dtypes; so not supported here.
1235
1241
a , axis , stable = _sort_helper (a , axis , kind , order )
1236
1242
result = torch .sort (a , dim = axis , stable = stable )
1237
1243
return result .values
1238
1244
1239
1245
1240
- def argsort (a : ArrayLike , axis = - 1 , kind = None , order : NotImplemented = None ):
1246
+ def argsort (a : ArrayLike , axis = - 1 , kind = None , order : NotImplementedType = None ):
1241
1247
a , axis , stable = _sort_helper (a , axis , kind , order )
1242
1248
return torch .argsort (a , dim = axis , stable = stable )
1243
1249
@@ -1316,7 +1322,7 @@ def squeeze(a: ArrayLike, axis=None):
1316
1322
return result
1317
1323
1318
1324
1319
- def reshape (a : ArrayLike , newshape , order : NotImplemented = "C" ):
1325
+ def reshape (a : ArrayLike , newshape , order : NotImplementedType = "C" ):
1320
1326
# if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
1321
1327
newshape = newshape [0 ] if len (newshape ) == 1 else newshape
1322
1328
return a .reshape (newshape )
@@ -1342,14 +1348,14 @@ def transpose(a: ArrayLike, axes=None):
1342
1348
return result
1343
1349
1344
1350
1345
- def ravel (a : ArrayLike , order : NotImplemented = "C" ):
1351
+ def ravel (a : ArrayLike , order : NotImplementedType = "C" ):
1346
1352
return torch .ravel (a )
1347
1353
1348
1354
1349
1355
# leading underscore since arr.flatten exists but np.flatten does not
1350
1356
1351
1357
1352
- def _flatten (a : ArrayLike , order : NotImplemented = "C" ):
1358
+ def _flatten (a : ArrayLike , order : NotImplementedType = "C" ):
1353
1359
# may return a copy
1354
1360
return torch .flatten (a )
1355
1361
@@ -1398,8 +1404,8 @@ def sum(
1398
1404
dtype : DTypeLike = None ,
1399
1405
out : Optional [OutArray ] = None ,
1400
1406
keepdims = NoValue ,
1401
- initial : NotImplemented = NoValue ,
1402
- where : NotImplemented = NoValue ,
1407
+ initial : NotImplementedType = NoValue ,
1408
+ where : NotImplementedType = NoValue ,
1403
1409
):
1404
1410
result = _impl .sum (
1405
1411
a , axis = axis , dtype = dtype , initial = initial , where = where , keepdims = keepdims
@@ -1413,8 +1419,8 @@ def prod(
1413
1419
dtype : DTypeLike = None ,
1414
1420
out : Optional [OutArray ] = None ,
1415
1421
keepdims = NoValue ,
1416
- initial : NotImplemented = NoValue ,
1417
- where : NotImplemented = NoValue ,
1422
+ initial : NotImplementedType = NoValue ,
1423
+ where : NotImplementedType = NoValue ,
1418
1424
):
1419
1425
result = _impl .prod (
1420
1426
a , axis = axis , dtype = dtype , initial = initial , where = where , keepdims = keepdims
@@ -1432,7 +1438,7 @@ def mean(
1432
1438
out : Optional [OutArray ] = None ,
1433
1439
keepdims = NoValue ,
1434
1440
* ,
1435
- where : NotImplemented = NoValue ,
1441
+ where : NotImplementedType = NoValue ,
1436
1442
):
1437
1443
result = _impl .mean (a , axis = axis , dtype = dtype , where = NoValue , keepdims = keepdims )
1438
1444
return result
@@ -1446,7 +1452,7 @@ def var(
1446
1452
ddof = 0 ,
1447
1453
keepdims = NoValue ,
1448
1454
* ,
1449
- where : NotImplemented = NoValue ,
1455
+ where : NotImplementedType = NoValue ,
1450
1456
):
1451
1457
result = _impl .var (
1452
1458
a , axis = axis , dtype = dtype , ddof = ddof , where = where , keepdims = keepdims
@@ -1462,7 +1468,7 @@ def std(
1462
1468
ddof = 0 ,
1463
1469
keepdims = NoValue ,
1464
1470
* ,
1465
- where : NotImplemented = NoValue ,
1471
+ where : NotImplementedType = NoValue ,
1466
1472
):
1467
1473
result = _impl .std (
1468
1474
a , axis = axis , dtype = dtype , ddof = ddof , where = where , keepdims = keepdims
@@ -1497,8 +1503,8 @@ def amax(
1497
1503
axis : AxisLike = None ,
1498
1504
out : Optional [OutArray ] = None ,
1499
1505
keepdims = NoValue ,
1500
- initial : NotImplemented = NoValue ,
1501
- where : NotImplemented = NoValue ,
1506
+ initial : NotImplementedType = NoValue ,
1507
+ where : NotImplementedType = NoValue ,
1502
1508
):
1503
1509
result = _impl .max (a , axis = axis , initial = initial , where = where , keepdims = keepdims )
1504
1510
return result
@@ -1512,8 +1518,8 @@ def amin(
1512
1518
axis : AxisLike = None ,
1513
1519
out : Optional [OutArray ] = None ,
1514
1520
keepdims = NoValue ,
1515
- initial : NotImplemented = NoValue ,
1516
- where : NotImplemented = NoValue ,
1521
+ initial : NotImplementedType = NoValue ,
1522
+ where : NotImplementedType = NoValue ,
1517
1523
):
1518
1524
result = _impl .min (a , axis = axis , initial = initial , where = where , keepdims = keepdims )
1519
1525
return result
@@ -1538,7 +1544,7 @@ def all(
1538
1544
out : Optional [OutArray ] = None ,
1539
1545
keepdims = NoValue ,
1540
1546
* ,
1541
- where : NotImplemented = NoValue ,
1547
+ where : NotImplementedType = NoValue ,
1542
1548
):
1543
1549
result = _impl .all (a , axis = axis , where = where , keepdims = keepdims )
1544
1550
return result
@@ -1550,7 +1556,7 @@ def any(
1550
1556
out : Optional [OutArray ] = None ,
1551
1557
keepdims = NoValue ,
1552
1558
* ,
1553
- where : NotImplemented = NoValue ,
1559
+ where : NotImplementedType = NoValue ,
1554
1560
):
1555
1561
result = _impl .any (a , axis = axis , where = where , keepdims = keepdims )
1556
1562
return result
@@ -1593,7 +1599,7 @@ def quantile(
1593
1599
method = "linear" ,
1594
1600
keepdims = False ,
1595
1601
* ,
1596
- interpolation : NotImplemented = None ,
1602
+ interpolation : NotImplementedType = None ,
1597
1603
):
1598
1604
result = _impl .quantile (
1599
1605
a ,
@@ -1616,7 +1622,7 @@ def percentile(
1616
1622
method = "linear" ,
1617
1623
keepdims = False ,
1618
1624
* ,
1619
- interpolation : NotImplemented = None ,
1625
+ interpolation : NotImplementedType = None ,
1620
1626
):
1621
1627
result = _impl .percentile (
1622
1628
a ,
0 commit comments