Skip to content

Commit 76de7ec

Browse files
committed
use compat validator
1 parent b3a9520 commit 76de7ec

File tree

2 files changed

+41
-63
lines changed

2 files changed

+41
-63
lines changed

pandas/core/arrays/numpy_.py

Lines changed: 39 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from functools import wraps
21
import numbers
32

43
import numpy as np
54

65
from pandas._libs import lib
6+
from pandas.compat.numpy import function as nv
77

88
from pandas.core.dtypes.common import extract_array
99
from pandas.core.dtypes.dtypes import ExtensionDtype
@@ -16,32 +16,6 @@
1616
from .base import ExtensionArray, ExtensionOpsMixin
1717

1818

19-
def _validate_reduction_args(func):
20-
"""
21-
Validate that unused keyword arguments are not set.
22-
"""
23-
check = {'out', 'keepdims', 'initial', 'dtype'}
24-
25-
msg = (
26-
"'{}' does not support the '{}' parameter."
27-
)
28-
29-
@wraps(func)
30-
def wrapper(*args, **kwargs):
31-
if len(args) > 1:
32-
raise TypeError(
33-
"'{}' requires keyword arguments.".format(func.__name__)
34-
)
35-
for kw, value in kwargs.items():
36-
if kw in check and value is not None:
37-
raise TypeError(msg.format(func.__name__, kw))
38-
elif kw == 'ovewrite_input' and kw is not False:
39-
raise TypeError(msg.format(func.__name__, kw))
40-
41-
return func(*args, **kwargs)
42-
return wrapper
43-
44-
4519
class PandasDtype(ExtensionDtype):
4620
"""
4721
A Pandas ExtensionDtype for NumPy dtypes.
@@ -333,72 +307,83 @@ def _reduce(self, name, skipna=True, **kwargs):
333307
)
334308
raise TypeError(msg.format(type(self).__name__, name))
335309

336-
@_validate_reduction_args
337-
def any(self, axis=None, out=None, keepdims=None, skipna=True):
310+
# @_validate_reduction_args
311+
def any(self, axis=None, out=None, keepdims=False, skipna=True):
312+
nv.validate_any((), dict(out=out, keepdims=keepdims))
338313
return nanops.nanany(self._ndarray, axis=axis, skipna=skipna)
339314

340-
@_validate_reduction_args
341-
def all(self, axis=None, out=None, keepdims=None, skipna=True):
315+
# @_validate_reduction_args
316+
def all(self, axis=None, out=None, keepdims=False, skipna=True):
317+
nv.validate_all((), dict(out=out, keepdims=keepdims))
342318
return nanops.nanall(self._ndarray, axis=axis, skipna=skipna)
343319

344-
@_validate_reduction_args
345-
def min(self, axis=None, out=None, keepdims=None, initial=None,
346-
skipna=True):
320+
# @_validate_reduction_args
321+
def min(self, axis=None, out=None, keepdims=False, skipna=True):
322+
nv.validate_min((), dict(out=out, keepdims=keepdims))
347323
return nanops.nanmin(self._ndarray, axis=axis, skipna=skipna)
348324

349-
@_validate_reduction_args
350-
def max(self, axis=None, out=None, keepdims=None, initial=None,
351-
skipna=True):
325+
# @_validate_reduction_args
326+
def max(self, axis=None, out=None, keepdims=False, skipna=True):
327+
nv.validate_max((), dict(out=out, keepdims=keepdims))
352328
return nanops.nanmax(self._ndarray, axis=axis, skipna=skipna)
353329

354-
@_validate_reduction_args
355-
def sum(self, axis=None, dtype=None, out=None, keepdims=None,
330+
# @_validate_reduction_args
331+
def sum(self, axis=None, dtype=None, out=None, keepdims=False,
356332
initial=None, skipna=True, min_count=0):
333+
nv.validate_sum((), dict(dtype=dtype, out=out, keepdims=keepdims,
334+
initial=initial))
357335
return nanops.nansum(self._ndarray, axis=axis, skipna=skipna,
358336
min_count=min_count)
359337

360-
@_validate_reduction_args
361-
def prod(self, axis=None, dtype=None, out=None, keepdims=None,
338+
def prod(self, axis=None, dtype=None, out=None, keepdims=False,
362339
initial=None, skipna=True, min_count=0):
340+
nv.validate_prod((), dict(dtype=dtype, out=out, keepdims=keepdims,
341+
initial=initial))
363342
return nanops.nanprod(self._ndarray, axis=axis, skipna=skipna,
364343
min_count=min_count)
365344

366-
@_validate_reduction_args
367-
def mean(self, axis=None, dtype=None, out=None, keepdims=None,
345+
def mean(self, axis=None, dtype=None, out=None, keepdims=False,
368346
skipna=True):
347+
nv.validate_mean((), dict(dtype=dtype, out=out, keepdims=keepdims))
369348
return nanops.nanmean(self._ndarray, axis=axis, skipna=skipna)
370349

371-
@_validate_reduction_args
372350
def median(self, axis=None, out=None, overwrite_input=False,
373351
keepdims=False, skipna=True):
352+
nv.validate_median((), dict(out=out, overwrite_input=overwrite_input,
353+
keepdims=keepdims))
374354
return nanops.nanmedian(self._ndarray, axis=axis, skipna=skipna)
375355

376-
@_validate_reduction_args
377-
def std(self, axis=None, dtype=None, out=None, ddof=1, keepdims=None,
356+
def std(self, axis=None, dtype=None, out=None, ddof=1, keepdims=False,
378357
skipna=True):
358+
nv.validate_stat_ddof_func((), dict(dtype=dtype, out=out,
359+
keepdims=keepdims))
379360
return nanops.nanstd(self._ndarray, axis=axis, skipna=skipna,
380361
ddof=ddof)
381362

382-
@_validate_reduction_args
383-
def var(self, axis=None, dtype=None, out=None, ddof=1, keepdims=None,
363+
def var(self, axis=None, dtype=None, out=None, ddof=1, keepdims=False,
384364
skipna=True):
365+
nv.validate_stat_ddof_func((), dict(dtype=dtype, out=out,
366+
keepdims=keepdims))
385367
return nanops.nanvar(self._ndarray, axis=axis, skipna=skipna,
386368
ddof=ddof)
387369

388-
@_validate_reduction_args
389-
def sem(self, axis=None, dtype=None, out=None, ddof=1, keepdims=None,
370+
def sem(self, axis=None, dtype=None, out=None, ddof=1, keepdims=False,
390371
skipna=True):
372+
nv.validate_stat_ddof_func((), dict(dtype=dtype, out=out,
373+
keepdims=keepdims))
391374
return nanops.nansem(self._ndarray, axis=axis, skipna=skipna,
392375
ddof=ddof)
393376

394-
@_validate_reduction_args
395-
def kurt(self, axis=None, dtype=None, out=None, keepdims=None,
377+
def kurt(self, axis=None, dtype=None, out=None, keepdims=False,
396378
skipna=True):
379+
nv.validate_stat_ddof_func((), dict(dtype=dtype, out=out,
380+
keepdims=keepdims))
397381
return nanops.nankurt(self._ndarray, axis=axis, skipna=skipna)
398382

399-
@_validate_reduction_args
400-
def skew(self, axis=None, dtype=None, out=None, keepdims=None,
383+
def skew(self, axis=None, dtype=None, out=None, keepdims=False,
401384
skipna=True):
385+
nv.validate_stat_ddof_func((), dict(dtype=dtype, out=out,
386+
keepdims=keepdims))
402387
return nanops.nanskew(self._ndarray, axis=axis, skipna=skipna)
403388

404389
# ------------------------------------------------------------------------

pandas/tests/arrays/test_numpy.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,10 @@ def test_from_sequence_dtype():
4747
tm.assert_extension_array_equal(result, expected)
4848

4949

50-
def test_validate_reduction_positional_args():
51-
arr = PandasArray(np.array([1, 2, 3]))
52-
53-
with pytest.raises(TypeError, match="'all' requires keyword arguments"):
54-
arr.all(0)
55-
56-
5750
def test_validate_reduction_keyword_args():
5851
arr = PandasArray(np.array([1, 2, 3]))
59-
msg = "'all' does not support the 'keepdims' parameter"
60-
with pytest.raises(TypeError, match=msg):
52+
msg = "the 'keepdims' parameter is not supported .*all"
53+
with pytest.raises(ValueError, match=msg):
6154
arr.all(keepdims=True)
6255

6356

0 commit comments

Comments
 (0)