Skip to content

Commit e2bd8d9

Browse files
committed
MAINT: nuke NoValue
1 parent ed292ed commit e2bd8d9

File tree

3 files changed

+45
-53
lines changed

3 files changed

+45
-53
lines changed

torch_np/_detail/_reductions.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,13 @@
44
Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
55
"""
66

7+
import functools
78
import typing
89

910
import torch
1011

1112
from . import _dtypes_impl, _util
1213

13-
NoValue = _util.NoValue
14-
15-
16-
import functools
17-
1814
############# XXX
1915
### From _util.axis_expand_func
2016

@@ -51,7 +47,7 @@ def wrapped(tensor, axis, *args, **kwds):
5147

5248
def emulate_keepdims(func):
5349
@functools.wraps(func)
54-
def wrapped(tensor, axis=None, keepdims=NoValue, *args, **kwds):
50+
def wrapped(tensor, axis=None, keepdims=None, *args, **kwds):
5551
result = func(tensor, axis=axis, *args, **kwds)
5652
if keepdims:
5753
result = _util.apply_keepdims(result, axis, tensor.ndim)
@@ -133,7 +129,7 @@ def argmin(tensor, axis=None):
133129

134130
@emulate_keepdims
135131
@deco_axis_expand
136-
def any(tensor, axis=None, *, where=NoValue):
132+
def any(tensor, axis=None, *, where=None):
137133
axis = _util.allow_only_single_axis(axis)
138134

139135
if axis is None:
@@ -145,7 +141,7 @@ def any(tensor, axis=None, *, where=NoValue):
145141

146142
@emulate_keepdims
147143
@deco_axis_expand
148-
def all(tensor, axis=None, *, where=NoValue):
144+
def all(tensor, axis=None, *, where=None):
149145
axis = _util.allow_only_single_axis(axis)
150146

151147
if axis is None:
@@ -157,13 +153,13 @@ def all(tensor, axis=None, *, where=NoValue):
157153

158154
@emulate_keepdims
159155
@deco_axis_expand
160-
def max(tensor, axis=None, initial=NoValue, where=NoValue):
156+
def max(tensor, axis=None, initial=None, where=None):
161157
return tensor.amax(axis)
162158

163159

164160
@emulate_keepdims
165161
@deco_axis_expand
166-
def min(tensor, axis=None, initial=NoValue, where=NoValue):
162+
def min(tensor, axis=None, initial=None, where=None):
167163
return tensor.amin(axis)
168164

169165

@@ -175,7 +171,7 @@ def ptp(tensor, axis=None):
175171

176172
@emulate_keepdims
177173
@deco_axis_expand
178-
def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
174+
def sum(tensor, axis=None, dtype=None, initial=None, where=None):
179175
assert dtype is None or isinstance(dtype, torch.dtype)
180176

181177
if dtype == torch.bool:
@@ -191,7 +187,7 @@ def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
191187

192188
@emulate_keepdims
193189
@deco_axis_expand
194-
def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
190+
def prod(tensor, axis=None, dtype=None, initial=None, where=None):
195191
axis = _util.allow_only_single_axis(axis)
196192

197193
if dtype == torch.bool:
@@ -207,7 +203,7 @@ def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
207203

208204
@emulate_keepdims
209205
@deco_axis_expand
210-
def mean(tensor, axis=None, dtype=None, *, where=NoValue):
206+
def mean(tensor, axis=None, dtype=None, *, where=None):
211207
dtype = _atleast_float(dtype, tensor.dtype)
212208

213209
is_half = dtype == torch.float16
@@ -228,7 +224,7 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue):
228224

229225
@emulate_keepdims
230226
@deco_axis_expand
231-
def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
227+
def std(tensor, axis=None, dtype=None, ddof=0, *, where=None):
232228
dtype = _atleast_float(dtype, tensor.dtype)
233229
tensor = _util.cast_if_needed(tensor, dtype)
234230
result = tensor.std(dim=axis, correction=ddof)
@@ -238,7 +234,7 @@ def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
238234

239235
@emulate_keepdims
240236
@deco_axis_expand
241-
def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
237+
def var(tensor, axis=None, dtype=None, ddof=0, *, where=None):
242238
dtype = _atleast_float(dtype, tensor.dtype)
243239
tensor = _util.cast_if_needed(tensor, dtype)
244240
result = tensor.var(dim=axis, correction=ddof)

torch_np/_detail/_util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from . import _dtypes_impl
99

10-
NoValue = None
1110

1211
# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
1312
def is_sequence(seq):

torch_np/_funcs_impl.py

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@
2828
normalize_array_like,
2929
)
3030

31-
NoValue = _util.NoValue
32-
33-
3431
###### array creation routines
3532

3633

@@ -44,7 +41,7 @@ def copyto(
4441
dst: NDArray,
4542
src: ArrayLike,
4643
casting="same_kind",
47-
where: NotImplementedType = NoValue,
44+
where: NotImplementedType = None,
4845
):
4946
(src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
5047
dst.copy_(src)
@@ -511,8 +508,8 @@ def corrcoef(
511508
x: ArrayLike,
512509
y: Optional[ArrayLike] = None,
513510
rowvar=True,
514-
bias=NoValue,
515-
ddof=NoValue,
511+
bias=None,
512+
ddof=None,
516513
*,
517514
dtype: DTypeLike = None,
518515
):
@@ -762,9 +759,9 @@ def nanmean(
762759
axis=None,
763760
dtype: DTypeLike = None,
764761
out: Optional[OutArray] = None,
765-
keepdims=NoValue,
762+
keepdims=None,
766763
*,
767-
where: NotImplementedType = NoValue,
764+
where: NotImplementedType = None,
768765
):
769766
# XXX: this needs to be rewritten
770767
if dtype is None:
@@ -1403,9 +1400,9 @@ def sum(
14031400
axis: AxisLike = None,
14041401
dtype: DTypeLike = None,
14051402
out: Optional[OutArray] = None,
1406-
keepdims=NoValue,
1407-
initial: NotImplementedType = NoValue,
1408-
where: NotImplementedType = NoValue,
1403+
keepdims=None,
1404+
initial: NotImplementedType = None,
1405+
where: NotImplementedType = None,
14091406
):
14101407
result = _impl.sum(
14111408
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
@@ -1418,9 +1415,9 @@ def prod(
14181415
axis: AxisLike = None,
14191416
dtype: DTypeLike = None,
14201417
out: Optional[OutArray] = None,
1421-
keepdims=NoValue,
1422-
initial: NotImplementedType = NoValue,
1423-
where: NotImplementedType = NoValue,
1418+
keepdims=None,
1419+
initial: NotImplementedType = None,
1420+
where: NotImplementedType = None,
14241421
):
14251422
result = _impl.prod(
14261423
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
@@ -1436,11 +1433,11 @@ def mean(
14361433
axis: AxisLike = None,
14371434
dtype: DTypeLike = None,
14381435
out: Optional[OutArray] = None,
1439-
keepdims=NoValue,
1436+
keepdims=None,
14401437
*,
1441-
where: NotImplementedType = NoValue,
1438+
where: NotImplementedType = None,
14421439
):
1443-
result = _impl.mean(a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims)
1440+
result = _impl.mean(a, axis=axis, dtype=dtype, where=None, keepdims=keepdims)
14441441
return result
14451442

14461443

@@ -1450,9 +1447,9 @@ def var(
14501447
dtype: DTypeLike = None,
14511448
out: Optional[OutArray] = None,
14521449
ddof=0,
1453-
keepdims=NoValue,
1450+
keepdims=None,
14541451
*,
1455-
where: NotImplementedType = NoValue,
1452+
where: NotImplementedType = None,
14561453
):
14571454
result = _impl.var(
14581455
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
@@ -1466,9 +1463,9 @@ def std(
14661463
dtype: DTypeLike = None,
14671464
out: Optional[OutArray] = None,
14681465
ddof=0,
1469-
keepdims=NoValue,
1466+
keepdims=None,
14701467
*,
1471-
where: NotImplementedType = NoValue,
1468+
where: NotImplementedType = None,
14721469
):
14731470
result = _impl.std(
14741471
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
@@ -1481,7 +1478,7 @@ def argmin(
14811478
axis: AxisLike = None,
14821479
out: Optional[OutArray] = None,
14831480
*,
1484-
keepdims=NoValue,
1481+
keepdims=None,
14851482
):
14861483
result = _impl.argmin(a, axis=axis, keepdims=keepdims)
14871484
return result
@@ -1492,7 +1489,7 @@ def argmax(
14921489
axis: AxisLike = None,
14931490
out: Optional[OutArray] = None,
14941491
*,
1495-
keepdims=NoValue,
1492+
keepdims=None,
14961493
):
14971494
result = _impl.argmax(a, axis=axis, keepdims=keepdims)
14981495
return result
@@ -1502,9 +1499,9 @@ def amax(
15021499
a: ArrayLike,
15031500
axis: AxisLike = None,
15041501
out: Optional[OutArray] = None,
1505-
keepdims=NoValue,
1506-
initial: NotImplementedType = NoValue,
1507-
where: NotImplementedType = NoValue,
1502+
keepdims=None,
1503+
initial: NotImplementedType = None,
1504+
where: NotImplementedType = None,
15081505
):
15091506
result = _impl.max(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
15101507
return result
@@ -1517,9 +1514,9 @@ def amin(
15171514
a: ArrayLike,
15181515
axis: AxisLike = None,
15191516
out: Optional[OutArray] = None,
1520-
keepdims=NoValue,
1521-
initial: NotImplementedType = NoValue,
1522-
where: NotImplementedType = NoValue,
1517+
keepdims=None,
1518+
initial: NotImplementedType = None,
1519+
where: NotImplementedType = None,
15231520
):
15241521
result = _impl.min(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
15251522
return result
@@ -1532,7 +1529,7 @@ def ptp(
15321529
a: ArrayLike,
15331530
axis: AxisLike = None,
15341531
out: Optional[OutArray] = None,
1535-
keepdims=NoValue,
1532+
keepdims=None,
15361533
):
15371534
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
15381535
return result
@@ -1542,9 +1539,9 @@ def all(
15421539
a: ArrayLike,
15431540
axis: AxisLike = None,
15441541
out: Optional[OutArray] = None,
1545-
keepdims=NoValue,
1542+
keepdims=None,
15461543
*,
1547-
where: NotImplementedType = NoValue,
1544+
where: NotImplementedType = None,
15481545
):
15491546
result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
15501547
return result
@@ -1554,9 +1551,9 @@ def any(
15541551
a: ArrayLike,
15551552
axis: AxisLike = None,
15561553
out: Optional[OutArray] = None,
1557-
keepdims=NoValue,
1554+
keepdims=None,
15581555
*,
1559-
where: NotImplementedType = NoValue,
1556+
where: NotImplementedType = None,
15601557
):
15611558
result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
15621559
return result
@@ -1659,7 +1656,7 @@ def average(
16591656
weights: ArrayLike = None,
16601657
returned=False,
16611658
*,
1662-
keepdims=NoValue,
1659+
keepdims=None,
16631660
):
16641661
result, wsum = _impl.average(a, axis, weights, returned=returned, keepdims=keepdims)
16651662
if returned:
@@ -1672,8 +1669,8 @@ def diff(
16721669
a: ArrayLike,
16731670
n=1,
16741671
axis=-1,
1675-
prepend: Optional[ArrayLike] = NoValue,
1676-
append: Optional[ArrayLike] = NoValue,
1672+
prepend: Optional[ArrayLike] = None,
1673+
append: Optional[ArrayLike] = None,
16771674
):
16781675
axis = _util.normalize_axis_index(axis, a.ndim)
16791676

0 commit comments

Comments
 (0)