Skip to content

Commit f9de80f

Browse files
committed
BUG: Prevent abuse of kwargs in stat functions
Filters kwargs argument in stat functions to prevent the passage of clearly invalid arguments while at the same time maintaining compatibility with analogous numpy functions. Closes pandas-devgh-12301.
1 parent b5aecab commit f9de80f

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

doc/source/whatsnew/v0.18.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,8 @@ Other API Changes
824824

825825
- As part of the new API for :ref:`window functions <whatsnew_0180.enhancements.moments>` and :ref:`resampling <whatsnew_0180.breaking.resample>`, aggregation functions have been clarified, raising more informative error messages on invalid aggregations. (:issue:`9052`). A full set of examples are presented in :ref:`groupby <groupby.aggregation>`.
826826

827+
- Statistical functions for ``NDFrame`` objects will now raise if non-numpy-compatible arguments are passed in for ``**kwargs`` (:issue:`12301`)
828+
827829
.. _whatsnew_0180.deprecations:
828830

829831
Deprecations

pandas/core/generic.py

+20
Original file line numberDiff line numberDiff line change
@@ -5207,12 +5207,29 @@ def _doc_parms(cls):
52075207
%(outname)s : %(name1)s\n"""
52085208

52095209

5210+
def _validate_kwargs(fname, kwargs, *compat_args):
5211+
"""
5212+
Checks whether parameters passed to the
5213+
**kwargs argument in a 'stat' function 'fname'
5214+
are valid parameters as specified in *compat_args
5215+
5216+
"""
5217+
list(map(kwargs.__delitem__, filter(
5218+
kwargs.__contains__, compat_args)))
5219+
if kwargs:
5220+
bad_arg = list(kwargs)[0] # first 'key' element
5221+
raise TypeError(("{fname}() got an unexpected "
5222+
"keyword argument '{arg}'".
5223+
format(fname=fname, arg=bad_arg)))
5224+
5225+
52105226
def _make_stat_function(name, name1, name2, axis_descr, desc, f):
52115227
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
52125228
axis_descr=axis_descr)
52135229
@Appender(_num_doc)
52145230
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
52155231
**kwargs):
5232+
_validate_kwargs(name, kwargs, 'out', 'dtype')
52165233
if skipna is None:
52175234
skipna = True
52185235
if axis is None:
@@ -5233,6 +5250,7 @@ def _make_stat_function_ddof(name, name1, name2, axis_descr, desc, f):
52335250
@Appender(_num_ddof_doc)
52345251
def stat_func(self, axis=None, skipna=None, level=None, ddof=1,
52355252
numeric_only=None, **kwargs):
5253+
_validate_kwargs(name, kwargs, 'out', 'dtype')
52365254
if skipna is None:
52375255
skipna = True
52385256
if axis is None:
@@ -5254,6 +5272,7 @@ def _make_cum_function(name, name1, name2, axis_descr, desc, accum_func,
52545272
@Appender("Return cumulative {0} over requested axis.".format(name) +
52555273
_cnum_doc)
52565274
def func(self, axis=None, dtype=None, out=None, skipna=True, **kwargs):
5275+
_validate_kwargs(name, kwargs, 'out', 'dtype')
52575276
if axis is None:
52585277
axis = self._stat_axis_number
52595278
else:
@@ -5288,6 +5307,7 @@ def _make_logical_function(name, name1, name2, axis_descr, desc, f):
52885307
@Appender(_bool_doc)
52895308
def logical_func(self, axis=None, bool_only=None, skipna=None, level=None,
52905309
**kwargs):
5310+
_validate_kwargs(name, kwargs, 'out', 'dtype')
52915311
if skipna is None:
52925312
skipna = True
52935313
if axis is None:

pandas/tests/test_generic.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
from pandas.compat import range, zip
1818
from pandas import compat
19-
from pandas.util.testing import (assert_series_equal,
19+
from pandas.util.testing import (assertRaisesRegexp,
20+
assert_series_equal,
2021
assert_frame_equal,
2122
assert_panel_equal,
2223
assert_panel4d_equal,
2324
assert_almost_equal,
2425
assert_equal)
26+
2527
import pandas.util.testing as tm
2628

2729

@@ -483,8 +485,6 @@ def test_split_compat(self):
483485
self.assertTrue(len(np.array_split(o, 2)) == 2)
484486

485487
def test_unexpected_keyword(self): # GH8597
486-
from pandas.util.testing import assertRaisesRegexp
487-
488488
df = DataFrame(np.random.randn(5, 2), columns=['jim', 'joe'])
489489
ca = pd.Categorical([0, 0, 2, 2, 3, np.nan])
490490
ts = df['joe'].copy()
@@ -502,6 +502,20 @@ def test_unexpected_keyword(self): # GH8597
502502
with assertRaisesRegexp(TypeError, 'unexpected keyword'):
503503
ts.fillna(0, in_place=True)
504504

505+
# See gh-12301
506+
def test_stat_unexpected_keyword(self):
507+
obj = self._construct(5)
508+
starwars = 'Star Wars'
509+
510+
with assertRaisesRegexp(TypeError, 'unexpected keyword'):
511+
obj.max(epic=starwars) # stat_function
512+
with assertRaisesRegexp(TypeError, 'unexpected keyword'):
513+
obj.var(epic=starwars) # stat_function_ddof
514+
with assertRaisesRegexp(TypeError, 'unexpected keyword'):
515+
obj.sum(epic=starwars) # cum_function
516+
with assertRaisesRegexp(TypeError, 'unexpected keyword'):
517+
obj.any(epic=starwars) # logical_function
518+
505519

506520
class TestSeries(tm.TestCase, Generic):
507521
_typ = Series

0 commit comments

Comments
 (0)