Skip to content

Commit 0dfeb24

Browse files
committed
MAINT: handle not implemented arguments centrally, via a special annotation
Abuse/reuse the NotImplemented singleton as the annotation.
1 parent 4199c18 commit 0dfeb24

File tree

4 files changed

+79
-160
lines changed

4 files changed

+79
-160
lines changed

torch_np/_detail/_reductions.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,6 @@ def argmin(tensor, axis=None):
134134
@emulate_keepdims
135135
@deco_axis_expand
136136
def any(tensor, axis=None, *, where=NoValue):
137-
if where is not NoValue:
138-
raise NotImplementedError
139-
140137
axis = _util.allow_only_single_axis(axis)
141138

142139
if axis is None:
@@ -149,9 +146,6 @@ def any(tensor, axis=None, *, where=NoValue):
149146
@emulate_keepdims
150147
@deco_axis_expand
151148
def all(tensor, axis=None, *, where=NoValue):
152-
if where is not NoValue:
153-
raise NotImplementedError
154-
155149
axis = _util.allow_only_single_axis(axis)
156150

157151
if axis is None:
@@ -164,36 +158,24 @@ def all(tensor, axis=None, *, where=NoValue):
164158
@emulate_keepdims
165159
@deco_axis_expand
166160
def max(tensor, axis=None, initial=NoValue, where=NoValue):
167-
if initial is not NoValue or where is not NoValue:
168-
raise NotImplementedError
169-
170-
result = tensor.amax(axis)
171-
return result
161+
return tensor.amax(axis)
172162

173163

174164
@emulate_keepdims
175165
@deco_axis_expand
176166
def min(tensor, axis=None, initial=NoValue, where=NoValue):
177-
if initial is not NoValue or where is not NoValue:
178-
raise NotImplementedError
179-
180-
result = tensor.amin(axis)
181-
return result
167+
return tensor.amin(axis)
182168

183169

184170
@emulate_keepdims
185171
@deco_axis_expand
186172
def ptp(tensor, axis=None):
187-
result = tensor.amax(axis) - tensor.amin(axis)
188-
return result
173+
return tensor.amax(axis) - tensor.amin(axis)
189174

190175

191176
@emulate_keepdims
192177
@deco_axis_expand
193178
def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
194-
if initial is not NoValue or where is not NoValue:
195-
raise NotImplementedError
196-
197179
assert dtype is None or isinstance(dtype, torch.dtype)
198180

199181
if dtype == torch.bool:
@@ -210,9 +192,6 @@ def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
210192
@emulate_keepdims
211193
@deco_axis_expand
212194
def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
213-
if initial is not NoValue or where is not NoValue:
214-
raise NotImplementedError
215-
216195
axis = _util.allow_only_single_axis(axis)
217196

218197
if dtype == torch.bool:
@@ -229,9 +208,6 @@ def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
229208
@emulate_keepdims
230209
@deco_axis_expand
231210
def mean(tensor, axis=None, dtype=None, *, where=NoValue):
232-
if where is not NoValue:
233-
raise NotImplementedError
234-
235211
dtype = _atleast_float(dtype, tensor.dtype)
236212

237213
is_half = dtype == torch.float16
@@ -253,9 +229,6 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue):
253229
@emulate_keepdims
254230
@deco_axis_expand
255231
def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
256-
if where is not NoValue:
257-
raise NotImplementedError
258-
259232
dtype = _atleast_float(dtype, tensor.dtype)
260233
tensor = _util.cast_if_needed(tensor, dtype)
261234
result = tensor.std(dim=axis, correction=ddof)
@@ -266,9 +239,6 @@ def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
266239
@emulate_keepdims
267240
@deco_axis_expand
268241
def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
269-
if where is not NoValue:
270-
raise NotImplementedError
271-
272242
dtype = _atleast_float(dtype, tensor.dtype)
273243
tensor = _util.cast_if_needed(tensor, dtype)
274244
result = tensor.var(dim=axis, correction=ddof)
@@ -387,9 +357,6 @@ def quantile(
387357
# Here we choose to work out-of-place because why not.
388358
pass
389359

390-
if interpolation is not None:
391-
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
392-
393360
if not a.dtype.is_floating_point:
394361
dtype = _dtypes_impl.default_float_dtype
395362
a = a.to(dtype)

0 commit comments

Comments (0)