Skip to content

Commit 1c80005

Browse files
committed
MAINT: use NotImplementedType not NotImplemented
1 parent 64213f7 commit 1c80005

File tree

4 files changed

+88
-64
lines changed

4 files changed

+88
-64
lines changed

torch_np/_funcs_impl.py

Lines changed: 55 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AxisLike,
2424
DTypeLike,
2525
NDArray,
26+
NotImplementedType,
2627
OutArray,
2728
normalize_array_like,
2829
)
@@ -33,12 +34,17 @@
3334
###### array creation routines
3435

3536

36-
def copy(a: ArrayLike, order: NotImplemented = "K", subok: NotImplemented = False):
37+
def copy(
38+
a: ArrayLike, order: NotImplementedType = "K", subok: NotImplementedType = False
39+
):
3740
return a.clone()
3841

3942

4043
def copyto(
41-
dst: NDArray, src: ArrayLike, casting="same_kind", where: NotImplemented = NoValue
44+
dst: NDArray,
45+
src: ArrayLike,
46+
casting="same_kind",
47+
where: NotImplementedType = NoValue,
4248
):
4349
(src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
4450
dst.copy_(src)
@@ -320,7 +326,7 @@ def arange(
320326
step: Optional[ArrayLike] = 1,
321327
dtype: DTypeLike = None,
322328
*,
323-
like: NotImplemented = None,
329+
like: NotImplementedType = None,
324330
):
325331
if step == 0:
326332
raise ZeroDivisionError
@@ -365,9 +371,9 @@ def arange(
365371
def empty(
366372
shape,
367373
dtype: DTypeLike = float,
368-
order: NotImplemented = "C",
374+
order: NotImplementedType = "C",
369375
*,
370-
like: NotImplemented = None,
376+
like: NotImplementedType = None,
371377
):
372378
if dtype is None:
373379
dtype = _dtypes_impl.default_float_dtype
@@ -381,8 +387,8 @@ def empty(
381387
def empty_like(
382388
prototype: ArrayLike,
383389
dtype: DTypeLike = None,
384-
order: NotImplemented = "K",
385-
subok: NotImplemented = False,
390+
order: NotImplementedType = "K",
391+
subok: NotImplementedType = False,
386392
shape=None,
387393
):
388394
result = torch.empty_like(prototype, dtype=dtype)
@@ -395,9 +401,9 @@ def full(
395401
shape,
396402
fill_value: ArrayLike,
397403
dtype: DTypeLike = None,
398-
order: NotImplemented = "C",
404+
order: NotImplementedType = "C",
399405
*,
400-
like: NotImplemented = None,
406+
like: NotImplementedType = None,
401407
):
402408
if isinstance(shape, int):
403409
shape = (shape,)
@@ -412,8 +418,8 @@ def full_like(
412418
a: ArrayLike,
413419
fill_value,
414420
dtype: DTypeLike = None,
415-
order: NotImplemented = "K",
416-
subok: NotImplemented = False,
421+
order: NotImplementedType = "K",
422+
subok: NotImplementedType = False,
417423
shape=None,
418424
):
419425
# XXX: fill_value broadcasts
@@ -426,9 +432,9 @@ def full_like(
426432
def ones(
427433
shape,
428434
dtype: DTypeLike = None,
429-
order: NotImplemented = "C",
435+
order: NotImplementedType = "C",
430436
*,
431-
like: NotImplemented = None,
437+
like: NotImplementedType = None,
432438
):
433439
if dtype is None:
434440
dtype = _dtypes_impl.default_float_dtype
@@ -438,8 +444,8 @@ def ones(
438444
def ones_like(
439445
a: ArrayLike,
440446
dtype: DTypeLike = None,
441-
order: NotImplemented = "K",
442-
subok: NotImplemented = False,
447+
order: NotImplementedType = "K",
448+
subok: NotImplementedType = False,
443449
shape=None,
444450
):
445451
result = torch.ones_like(a, dtype=dtype)
@@ -451,9 +457,9 @@ def ones_like(
451457
def zeros(
452458
shape,
453459
dtype: DTypeLike = None,
454-
order: NotImplemented = "C",
460+
order: NotImplementedType = "C",
455461
*,
456-
like: NotImplemented = None,
462+
like: NotImplementedType = None,
457463
):
458464
if dtype is None:
459465
dtype = _dtypes_impl.default_float_dtype
@@ -463,8 +469,8 @@ def zeros(
463469
def zeros_like(
464470
a: ArrayLike,
465471
dtype: DTypeLike = None,
466-
order: NotImplemented = "K",
467-
subok: NotImplemented = False,
472+
order: NotImplementedType = "K",
473+
subok: NotImplementedType = False,
468474
shape=None,
469475
):
470476
result = torch.zeros_like(a, dtype=dtype)
@@ -647,14 +653,14 @@ def rot90(m: ArrayLike, k=1, axes=(0, 1)):
647653
# ### broadcasting and indices ###
648654

649655

650-
def broadcast_to(array: ArrayLike, shape, subok: NotImplemented = False):
656+
def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
651657
return torch.broadcast_to(array, size=shape)
652658

653659

654660
from torch import broadcast_shapes
655661

656662

657-
def broadcast_arrays(*args: ArrayLike, subok: NotImplemented = False):
663+
def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False):
658664
return torch.broadcast_tensors(*args)
659665

660666

@@ -740,7 +746,7 @@ def triu_indices_from(arr: ArrayLike, k=0):
740746
return tuple(result)
741747

742748

743-
def tri(N, M=None, k=0, dtype: DTypeLike = float, *, like: NotImplemented = None):
749+
def tri(N, M=None, k=0, dtype: DTypeLike = float, *, like: NotImplementedType = None):
744750
if M is None:
745751
M = N
746752
tensor = torch.ones((N, M), dtype=dtype)
@@ -758,7 +764,7 @@ def nanmean(
758764
out: Optional[OutArray] = None,
759765
keepdims=NoValue,
760766
*,
761-
where: NotImplemented = NoValue,
767+
where: NotImplementedType = NoValue,
762768
):
763769
# XXX: this needs to be rewritten
764770
if dtype is None:
@@ -892,7 +898,7 @@ def take(
892898
indices: ArrayLike,
893899
axis=None,
894900
out: Optional[OutArray] = None,
895-
mode: NotImplemented = "raise",
901+
mode: NotImplementedType = "raise",
896902
):
897903
(a,), axis = _util.axis_none_ravel(a, axis=axis)
898904
axis = _util.normalize_axis_index(axis, a.ndim)
@@ -923,12 +929,12 @@ def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
923929

924930
def unique(
925931
ar: ArrayLike,
926-
return_index: NotImplemented = False,
932+
return_index: NotImplementedType = False,
927933
return_inverse=False,
928934
return_counts=False,
929935
axis=None,
930936
*,
931-
equal_nan: NotImplemented = True,
937+
equal_nan: NotImplementedType = True,
932938
):
933939
if axis is None:
934940
ar = ar.ravel()
@@ -1074,9 +1080,9 @@ def eye(
10741080
M=None,
10751081
k=0,
10761082
dtype: DTypeLike = float,
1077-
order: NotImplemented = "C",
1083+
order: NotImplementedType = "C",
10781084
*,
1079-
like: NotImplemented = None,
1085+
like: NotImplementedType = None,
10801086
):
10811087
if M is None:
10821088
M = N
@@ -1085,7 +1091,7 @@ def eye(
10851091
return z
10861092

10871093

1088-
def identity(n, dtype: DTypeLike = None, *, like: NotImplemented = None):
1094+
def identity(n, dtype: DTypeLike = None, *, like: NotImplementedType = None):
10891095
return torch.eye(n, dtype=dtype)
10901096

10911097

@@ -1230,14 +1236,14 @@ def _sort_helper(tensor, axis, kind, order):
12301236
return tensor, axis, stable
12311237

12321238

1233-
def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplemented = None):
1239+
def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
12341240
# `order` keyword arg is only relevant for structured dtypes; so not supported here.
12351241
a, axis, stable = _sort_helper(a, axis, kind, order)
12361242
result = torch.sort(a, dim=axis, stable=stable)
12371243
return result.values
12381244

12391245

1240-
def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplemented = None):
1246+
def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
12411247
a, axis, stable = _sort_helper(a, axis, kind, order)
12421248
return torch.argsort(a, dim=axis, stable=stable)
12431249

@@ -1316,7 +1322,7 @@ def squeeze(a: ArrayLike, axis=None):
13161322
return result
13171323

13181324

1319-
def reshape(a: ArrayLike, newshape, order: NotImplemented = "C"):
1325+
def reshape(a: ArrayLike, newshape, order: NotImplementedType = "C"):
13201326
# if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
13211327
newshape = newshape[0] if len(newshape) == 1 else newshape
13221328
return a.reshape(newshape)
@@ -1342,14 +1348,14 @@ def transpose(a: ArrayLike, axes=None):
13421348
return result
13431349

13441350

1345-
def ravel(a: ArrayLike, order: NotImplemented = "C"):
1351+
def ravel(a: ArrayLike, order: NotImplementedType = "C"):
13461352
return torch.ravel(a)
13471353

13481354

13491355
# leading underscore since arr.flatten exists but np.flatten does not
13501356

13511357

1352-
def _flatten(a: ArrayLike, order: NotImplemented = "C"):
1358+
def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
13531359
# may return a copy
13541360
return torch.flatten(a)
13551361

@@ -1398,8 +1404,8 @@ def sum(
13981404
dtype: DTypeLike = None,
13991405
out: Optional[OutArray] = None,
14001406
keepdims=NoValue,
1401-
initial: NotImplemented = NoValue,
1402-
where: NotImplemented = NoValue,
1407+
initial: NotImplementedType = NoValue,
1408+
where: NotImplementedType = NoValue,
14031409
):
14041410
result = _impl.sum(
14051411
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
@@ -1413,8 +1419,8 @@ def prod(
14131419
dtype: DTypeLike = None,
14141420
out: Optional[OutArray] = None,
14151421
keepdims=NoValue,
1416-
initial: NotImplemented = NoValue,
1417-
where: NotImplemented = NoValue,
1422+
initial: NotImplementedType = NoValue,
1423+
where: NotImplementedType = NoValue,
14181424
):
14191425
result = _impl.prod(
14201426
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
@@ -1432,7 +1438,7 @@ def mean(
14321438
out: Optional[OutArray] = None,
14331439
keepdims=NoValue,
14341440
*,
1435-
where: NotImplemented = NoValue,
1441+
where: NotImplementedType = NoValue,
14361442
):
14371443
result = _impl.mean(a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims)
14381444
return result
@@ -1446,7 +1452,7 @@ def var(
14461452
ddof=0,
14471453
keepdims=NoValue,
14481454
*,
1449-
where: NotImplemented = NoValue,
1455+
where: NotImplementedType = NoValue,
14501456
):
14511457
result = _impl.var(
14521458
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
@@ -1462,7 +1468,7 @@ def std(
14621468
ddof=0,
14631469
keepdims=NoValue,
14641470
*,
1465-
where: NotImplemented = NoValue,
1471+
where: NotImplementedType = NoValue,
14661472
):
14671473
result = _impl.std(
14681474
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
@@ -1497,8 +1503,8 @@ def amax(
14971503
axis: AxisLike = None,
14981504
out: Optional[OutArray] = None,
14991505
keepdims=NoValue,
1500-
initial: NotImplemented = NoValue,
1501-
where: NotImplemented = NoValue,
1506+
initial: NotImplementedType = NoValue,
1507+
where: NotImplementedType = NoValue,
15021508
):
15031509
result = _impl.max(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
15041510
return result
@@ -1512,8 +1518,8 @@ def amin(
15121518
axis: AxisLike = None,
15131519
out: Optional[OutArray] = None,
15141520
keepdims=NoValue,
1515-
initial: NotImplemented = NoValue,
1516-
where: NotImplemented = NoValue,
1521+
initial: NotImplementedType = NoValue,
1522+
where: NotImplementedType = NoValue,
15171523
):
15181524
result = _impl.min(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
15191525
return result
@@ -1538,7 +1544,7 @@ def all(
15381544
out: Optional[OutArray] = None,
15391545
keepdims=NoValue,
15401546
*,
1541-
where: NotImplemented = NoValue,
1547+
where: NotImplementedType = NoValue,
15421548
):
15431549
result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
15441550
return result
@@ -1550,7 +1556,7 @@ def any(
15501556
out: Optional[OutArray] = None,
15511557
keepdims=NoValue,
15521558
*,
1553-
where: NotImplemented = NoValue,
1559+
where: NotImplementedType = NoValue,
15541560
):
15551561
result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
15561562
return result
@@ -1593,7 +1599,7 @@ def quantile(
15931599
method="linear",
15941600
keepdims=False,
15951601
*,
1596-
interpolation: NotImplemented = None,
1602+
interpolation: NotImplementedType = None,
15971603
):
15981604
result = _impl.quantile(
15991605
a,
@@ -1616,7 +1622,7 @@ def percentile(
16161622
method="linear",
16171623
keepdims=False,
16181624
*,
1619-
interpolation: NotImplemented = None,
1625+
interpolation: NotImplementedType = None,
16201626
):
16211627
result = _impl.percentile(
16221628
a,

torch_np/_normalizations.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
#
2929
OutArray = typing.TypeVar("OutArray")
3030

31+
try:
32+
from typing import NotImplementedType
33+
except ImportError:
34+
NotImplementedType = typing.TypeVar("NotImplementedType")
35+
3136

3237
import inspect
3338

@@ -105,7 +110,7 @@ def normalize_outarray(arg, parm=None):
105110
"NDArray": normalize_ndarray,
106111
"DTypeLike": normalize_dtype,
107112
"AxisLike": normalize_axis_like,
108-
NotImplemented: normalize_not_implemented,
113+
"NotImplementedType": normalize_not_implemented,
109114
}
110115

111116

0 commit comments

Comments
 (0)