Skip to content

Commit 9250124

Browse files
authored
Merge pull request #114 from Quansight-Labs/notimpl_annotation
MAINT: handle not implemented arguments centrally, via a dedicated annotation
2 parents 5a7e423 + e2bd8d9 commit 9250124

File tree

6 files changed

+166
-206
lines changed

6 files changed

+166
-206
lines changed

torch_np/_detail/_reductions.py

Lines changed: 14 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,13 @@
44
Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
55
"""
66

7+
import functools
78
import typing
89

910
import torch
1011

1112
from . import _dtypes_impl, _util
1213

13-
NoValue = _util.NoValue
14-
15-
16-
import functools
17-
1814
############# XXX
1915
### From _util.axis_expand_func
2016

@@ -51,7 +47,7 @@ def wrapped(tensor, axis, *args, **kwds):
5147

5248
def emulate_keepdims(func):
5349
@functools.wraps(func)
54-
def wrapped(tensor, axis=None, keepdims=NoValue, *args, **kwds):
50+
def wrapped(tensor, axis=None, keepdims=None, *args, **kwds):
5551
result = func(tensor, axis=axis, *args, **kwds)
5652
if keepdims:
5753
result = _util.apply_keepdims(result, axis, tensor.ndim)
@@ -133,10 +129,7 @@ def argmin(tensor, axis=None):
133129

134130
@emulate_keepdims
135131
@deco_axis_expand
136-
def any(tensor, axis=None, *, where=NoValue):
137-
if where is not NoValue:
138-
raise NotImplementedError
139-
132+
def any(tensor, axis=None, *, where=None):
140133
axis = _util.allow_only_single_axis(axis)
141134

142135
if axis is None:
@@ -148,10 +141,7 @@ def any(tensor, axis=None, *, where=NoValue):
148141

149142
@emulate_keepdims
150143
@deco_axis_expand
151-
def all(tensor, axis=None, *, where=NoValue):
152-
if where is not NoValue:
153-
raise NotImplementedError
154-
144+
def all(tensor, axis=None, *, where=None):
155145
axis = _util.allow_only_single_axis(axis)
156146

157147
if axis is None:
@@ -163,37 +153,25 @@ def all(tensor, axis=None, *, where=NoValue):
163153

164154
@emulate_keepdims
165155
@deco_axis_expand
166-
def max(tensor, axis=None, initial=NoValue, where=NoValue):
167-
if initial is not NoValue or where is not NoValue:
168-
raise NotImplementedError
169-
170-
result = tensor.amax(axis)
171-
return result
156+
def max(tensor, axis=None, initial=None, where=None):
157+
return tensor.amax(axis)
172158

173159

174160
@emulate_keepdims
175161
@deco_axis_expand
176-
def min(tensor, axis=None, initial=NoValue, where=NoValue):
177-
if initial is not NoValue or where is not NoValue:
178-
raise NotImplementedError
179-
180-
result = tensor.amin(axis)
181-
return result
162+
def min(tensor, axis=None, initial=None, where=None):
163+
return tensor.amin(axis)
182164

183165

184166
@emulate_keepdims
185167
@deco_axis_expand
186168
def ptp(tensor, axis=None):
187-
result = tensor.amax(axis) - tensor.amin(axis)
188-
return result
169+
return tensor.amax(axis) - tensor.amin(axis)
189170

190171

191172
@emulate_keepdims
192173
@deco_axis_expand
193-
def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
194-
if initial is not NoValue or where is not NoValue:
195-
raise NotImplementedError
196-
174+
def sum(tensor, axis=None, dtype=None, initial=None, where=None):
197175
assert dtype is None or isinstance(dtype, torch.dtype)
198176

199177
if dtype == torch.bool:
@@ -209,10 +187,7 @@ def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
209187

210188
@emulate_keepdims
211189
@deco_axis_expand
212-
def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
213-
if initial is not NoValue or where is not NoValue:
214-
raise NotImplementedError
215-
190+
def prod(tensor, axis=None, dtype=None, initial=None, where=None):
216191
axis = _util.allow_only_single_axis(axis)
217192

218193
if dtype == torch.bool:
@@ -228,10 +203,7 @@ def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
228203

229204
@emulate_keepdims
230205
@deco_axis_expand
231-
def mean(tensor, axis=None, dtype=None, *, where=NoValue):
232-
if where is not NoValue:
233-
raise NotImplementedError
234-
206+
def mean(tensor, axis=None, dtype=None, *, where=None):
235207
dtype = _atleast_float(dtype, tensor.dtype)
236208

237209
is_half = dtype == torch.float16
@@ -252,10 +224,7 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue):
252224

253225
@emulate_keepdims
254226
@deco_axis_expand
255-
def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
256-
if where is not NoValue:
257-
raise NotImplementedError
258-
227+
def std(tensor, axis=None, dtype=None, ddof=0, *, where=None):
259228
dtype = _atleast_float(dtype, tensor.dtype)
260229
tensor = _util.cast_if_needed(tensor, dtype)
261230
result = tensor.std(dim=axis, correction=ddof)
@@ -265,10 +234,7 @@ def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
265234

266235
@emulate_keepdims
267236
@deco_axis_expand
268-
def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
269-
if where is not NoValue:
270-
raise NotImplementedError
271-
237+
def var(tensor, axis=None, dtype=None, ddof=0, *, where=None):
272238
dtype = _atleast_float(dtype, tensor.dtype)
273239
tensor = _util.cast_if_needed(tensor, dtype)
274240
result = tensor.var(dim=axis, correction=ddof)
@@ -387,9 +353,6 @@ def quantile(
387353
# Here we choose to work out-of-place because why not.
388354
pass
389355

390-
if interpolation is not None:
391-
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
392-
393356
if not a.dtype.is_floating_point:
394357
dtype = _dtypes_impl.default_float_dtype
395358
a = a.to(dtype)

torch_np/_detail/_util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from . import _dtypes_impl
99

10-
NoValue = None
1110

1211
# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
1312
def is_sequence(seq):

0 commit comments

Comments (0)