Skip to content

Commit 1ab873c

Browse files
TomAugspurgerPingviinituutti
authored andcommitted
COMPAT: Add keepdims and friends to validation (pandas-dev#24356)
1 parent 9417188 commit 1ab873c

File tree

4 files changed

+67
-5
lines changed

4 files changed

+67
-5
lines changed

doc/source/whatsnew/v0.24.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1400,6 +1400,7 @@ Numeric
14001400
- Added ``log10`` to the list of supported functions in :meth:`DataFrame.eval` (:issue:`24139`)
14011401
- Logical operations ``&, |, ^`` between :class:`Series` and :class:`Index` will no longer raise ``ValueError`` (:issue:`22092`)
14021402
- Checking PEP 3141 numbers in :func:`~pandas.api.types.is_scalar` function returns ``True`` (:issue:`22903`)
1403+
- Reduction methods like :meth:`Series.sum` now accept the default value of ``keepdims=False`` when called from a NumPy ufunc, rather than raising a ``TypeError``. Full support for ``keepdims`` has not been implemented (:issue:`24356`).
14031404

14041405
Conversion
14051406
^^^^^^^^^^

pandas/compat/numpy/function.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,16 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name):
189189
ALLANY_DEFAULTS = OrderedDict()
190190
ALLANY_DEFAULTS['dtype'] = None
191191
ALLANY_DEFAULTS['out'] = None
192+
ALLANY_DEFAULTS['keepdims'] = False
192193
validate_all = CompatValidator(ALLANY_DEFAULTS, fname='all',
193194
method='both', max_fname_arg_count=1)
194195
validate_any = CompatValidator(ALLANY_DEFAULTS, fname='any',
195196
method='both', max_fname_arg_count=1)
196197

197-
LOGICAL_FUNC_DEFAULTS = dict(out=None)
198+
LOGICAL_FUNC_DEFAULTS = dict(out=None, keepdims=False)
198199
validate_logical_func = CompatValidator(LOGICAL_FUNC_DEFAULTS, method='kwargs')
199200

200-
MINMAX_DEFAULTS = dict(out=None)
201+
MINMAX_DEFAULTS = dict(out=None, keepdims=False)
201202
validate_min = CompatValidator(MINMAX_DEFAULTS, fname='min',
202203
method='both', max_fname_arg_count=1)
203204
validate_max = CompatValidator(MINMAX_DEFAULTS, fname='max',
@@ -225,16 +226,32 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name):
225226
STAT_FUNC_DEFAULTS = OrderedDict()
226227
STAT_FUNC_DEFAULTS['dtype'] = None
227228
STAT_FUNC_DEFAULTS['out'] = None
229+
230+
PROD_DEFAULTS = SUM_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
231+
SUM_DEFAULTS['keepdims'] = False
232+
SUM_DEFAULTS['initial'] = None
233+
234+
MEDIAN_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
235+
MEDIAN_DEFAULTS['overwrite_input'] = False
236+
MEDIAN_DEFAULTS['keepdims'] = False
237+
238+
STAT_FUNC_DEFAULTS['keepdims'] = False
239+
228240
validate_stat_func = CompatValidator(STAT_FUNC_DEFAULTS,
229241
method='kwargs')
230-
validate_sum = CompatValidator(STAT_FUNC_DEFAULTS, fname='sort',
242+
validate_sum = CompatValidator(SUM_DEFAULTS, fname='sum',
231243
method='both', max_fname_arg_count=1)
244+
validate_prod = CompatValidator(PROD_DEFAULTS, fname="prod",
245+
method="both", max_fname_arg_count=1)
232246
validate_mean = CompatValidator(STAT_FUNC_DEFAULTS, fname='mean',
233247
method='both', max_fname_arg_count=1)
248+
validate_median = CompatValidator(MEDIAN_DEFAULTS, fname='median',
249+
method='both', max_fname_arg_count=1)
234250

235251
STAT_DDOF_FUNC_DEFAULTS = OrderedDict()
236252
STAT_DDOF_FUNC_DEFAULTS['dtype'] = None
237253
STAT_DDOF_FUNC_DEFAULTS['out'] = None
254+
STAT_DDOF_FUNC_DEFAULTS['keepdims'] = False
238255
validate_stat_ddof_func = CompatValidator(STAT_DDOF_FUNC_DEFAULTS,
239256
method='kwargs')
240257

pandas/core/generic.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -10834,7 +10834,12 @@ def _make_min_count_stat_function(cls, name, name1, name2, axis_descr, desc,
1083410834
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
1083510835
min_count=0,
1083610836
**kwargs):
10837-
nv.validate_stat_func(tuple(), kwargs, fname=name)
10837+
if name == 'sum':
10838+
nv.validate_sum(tuple(), kwargs)
10839+
elif name == 'prod':
10840+
nv.validate_prod(tuple(), kwargs)
10841+
else:
10842+
nv.validate_stat_func(tuple(), kwargs, fname=name)
1083810843
if skipna is None:
1083910844
skipna = True
1084010845
if axis is None:
@@ -10855,7 +10860,10 @@ def _make_stat_function(cls, name, name1, name2, axis_descr, desc, f,
1085510860
@Appender(_num_doc)
1085610861
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
1085710862
**kwargs):
10858-
nv.validate_stat_func(tuple(), kwargs, fname=name)
10863+
if name == 'median':
10864+
nv.validate_median(tuple(), kwargs)
10865+
else:
10866+
nv.validate_stat_func(tuple(), kwargs, fname=name)
1085910867
if skipna is None:
1086010868
skipna = True
1086110869
if axis is None:

pandas/tests/series/test_analytics.py

+36
Original file line numberDiff line numberDiff line change
@@ -1641,6 +1641,42 @@ def test_value_counts_categorical_not_ordered(self):
16411641
tm.assert_series_equal(s.value_counts(normalize=True), exp)
16421642
tm.assert_series_equal(idx.value_counts(normalize=True), exp)
16431643

1644+
@pytest.mark.parametrize("func", [np.any, np.all])
1645+
@pytest.mark.parametrize("kwargs", [
1646+
dict(keepdims=True),
1647+
dict(out=object()),
1648+
])
1649+
@td.skip_if_np_lt_115
1650+
def test_validate_any_all_out_keepdims_raises(self, kwargs, func):
1651+
s = pd.Series([1, 2])
1652+
param = list(kwargs)[0]
1653+
name = func.__name__
1654+
1655+
msg = "the '{}' parameter .* {}".format(param, name)
1656+
with pytest.raises(ValueError, match=msg):
1657+
func(s, **kwargs)
1658+
1659+
@td.skip_if_np_lt_115
1660+
def test_validate_sum_initial(self):
1661+
s = pd.Series([1, 2])
1662+
with pytest.raises(ValueError, match="the 'initial' .* sum"):
1663+
np.sum(s, initial=10)
1664+
1665+
def test_validate_median_initial(self):
1666+
s = pd.Series([1, 2])
1667+
with pytest.raises(ValueError,
1668+
match="the 'overwrite_input' .* median"):
1669+
# It seems like np.median doesn't dispatch, so we use the
1670+
# method instead of the ufunc.
1671+
s.median(overwrite_input=True)
1672+
1673+
@td.skip_if_np_lt_115
1674+
def test_validate_stat_keepdims(self):
1675+
s = pd.Series([1, 2])
1676+
with pytest.raises(ValueError,
1677+
match="the 'keepdims'"):
1678+
np.sum(s, keepdims=True)
1679+
16441680

16451681
main_dtypes = [
16461682
'datetime',

0 commit comments

Comments
 (0)