Skip to content

Commit 9e83338

Browse files
committed
BUG: Better handling of invalid na_option argument for groupby.rank (#22124)
1 parent c272c52 commit 9e83338

File tree

3 files changed

+40
-27
lines changed

3 files changed

+40
-27
lines changed

doc/source/whatsnew/v0.24.0.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ Reshaping
615615
- Bug in :meth:`Series.combine_first` with ``datetime64[ns, tz]`` dtype which would return tz-naive result (:issue:`21469`)
616616
- Bug in :meth:`Series.where` and :meth:`DataFrame.where` with ``datetime64[ns, tz]`` dtype (:issue:`21546`)
617617
- Bug in :meth:`Series.mask` and :meth:`DataFrame.mask` with ``list`` conditionals (:issue:`21891`)
618-
-
618+
- :func:`pandas.core.groupby.GroupBy.rank` now raises a ``ValueError`` when an invalid value is passed for argument ``na_option`` (:issue:`22124`)
619619
-
620620

621621
Build Changes

pandas/core/groupby/groupby.py

+3
Original file line numberDiff line numberDiff line change
@@ -1705,6 +1705,9 @@ def rank(self, method='average', ascending=True, na_option='keep',
17051705
-----
17061706
DataFrame with ranking of values within each group
17071707
"""
1708+
if na_option not in {'keep', 'top', 'bottom'}:
1709+
msg = "na_option must be one of 'keep', 'top', or 'bottom'"
1710+
raise ValueError(msg)
17081711
return self._cython_transform('rank', numeric_only=False,
17091712
ties_method=method, ascending=ascending,
17101713
na_option=na_option, pct=pct, axis=axis)

pandas/tests/groupby/test_rank.py

+36-26
Original file line numberDiff line numberDiff line change
@@ -172,35 +172,35 @@ def test_infs_n_nans(grps, vals, ties_method, ascending, na_option, exp):
172172
[3., 3., np.nan, 1., 3., 2., np.nan, np.nan]),
173173
('dense', False, 'keep', True,
174174
[3. / 3., 3. / 3., np.nan, 1. / 3., 3. / 3., 2. / 3., np.nan, np.nan]),
175-
('average', True, 'no_na', False, [2., 2., 7., 5., 2., 4., 7., 7.]),
176-
('average', True, 'no_na', True,
175+
('average', True, 'bottom', False, [2., 2., 7., 5., 2., 4., 7., 7.]),
176+
('average', True, 'bottom', True,
177177
[0.25, 0.25, 0.875, 0.625, 0.25, 0.5, 0.875, 0.875]),
178-
('average', False, 'no_na', False, [4., 4., 7., 1., 4., 2., 7., 7.]),
179-
('average', False, 'no_na', True,
178+
('average', False, 'bottom', False, [4., 4., 7., 1., 4., 2., 7., 7.]),
179+
('average', False, 'bottom', True,
180180
[0.5, 0.5, 0.875, 0.125, 0.5, 0.25, 0.875, 0.875]),
181-
('min', True, 'no_na', False, [1., 1., 6., 5., 1., 4., 6., 6.]),
182-
('min', True, 'no_na', True,
181+
('min', True, 'bottom', False, [1., 1., 6., 5., 1., 4., 6., 6.]),
182+
('min', True, 'bottom', True,
183183
[0.125, 0.125, 0.75, 0.625, 0.125, 0.5, 0.75, 0.75]),
184-
('min', False, 'no_na', False, [3., 3., 6., 1., 3., 2., 6., 6.]),
185-
('min', False, 'no_na', True,
184+
('min', False, 'bottom', False, [3., 3., 6., 1., 3., 2., 6., 6.]),
185+
('min', False, 'bottom', True,
186186
[0.375, 0.375, 0.75, 0.125, 0.375, 0.25, 0.75, 0.75]),
187-
('max', True, 'no_na', False, [3., 3., 8., 5., 3., 4., 8., 8.]),
188-
('max', True, 'no_na', True,
187+
('max', True, 'bottom', False, [3., 3., 8., 5., 3., 4., 8., 8.]),
188+
('max', True, 'bottom', True,
189189
[0.375, 0.375, 1., 0.625, 0.375, 0.5, 1., 1.]),
190-
('max', False, 'no_na', False, [5., 5., 8., 1., 5., 2., 8., 8.]),
191-
('max', False, 'no_na', True,
190+
('max', False, 'bottom', False, [5., 5., 8., 1., 5., 2., 8., 8.]),
191+
('max', False, 'bottom', True,
192192
[0.625, 0.625, 1., 0.125, 0.625, 0.25, 1., 1.]),
193-
('first', True, 'no_na', False, [1., 2., 6., 5., 3., 4., 7., 8.]),
194-
('first', True, 'no_na', True,
193+
('first', True, 'bottom', False, [1., 2., 6., 5., 3., 4., 7., 8.]),
194+
('first', True, 'bottom', True,
195195
[0.125, 0.25, 0.75, 0.625, 0.375, 0.5, 0.875, 1.]),
196-
('first', False, 'no_na', False, [3., 4., 6., 1., 5., 2., 7., 8.]),
197-
('first', False, 'no_na', True,
196+
('first', False, 'bottom', False, [3., 4., 6., 1., 5., 2., 7., 8.]),
197+
('first', False, 'bottom', True,
198198
[0.375, 0.5, 0.75, 0.125, 0.625, 0.25, 0.875, 1.]),
199-
('dense', True, 'no_na', False, [1., 1., 4., 3., 1., 2., 4., 4.]),
200-
('dense', True, 'no_na', True,
199+
('dense', True, 'bottom', False, [1., 1., 4., 3., 1., 2., 4., 4.]),
200+
('dense', True, 'bottom', True,
201201
[0.25, 0.25, 1., 0.75, 0.25, 0.5, 1., 1.]),
202-
('dense', False, 'no_na', False, [3., 3., 4., 1., 3., 2., 4., 4.]),
203-
('dense', False, 'no_na', True,
202+
('dense', False, 'bottom', False, [3., 3., 4., 1., 3., 2., 4., 4.]),
203+
('dense', False, 'bottom', True,
204204
[0.75, 0.75, 1., 0.25, 0.75, 0.5, 1., 1.])
205205
])
206206
def test_rank_args_missing(grps, vals, ties_method, ascending,
@@ -252,14 +252,24 @@ def test_rank_object_raises(ties_method, ascending, na_option,
252252
with tm.assert_raises_regex(TypeError, "not callable"):
253253
df.groupby('key').rank(method=ties_method,
254254
ascending=ascending,
255-
na_option='bad', pct=pct)
255+
na_option=na_option, pct=pct)
256256

257-
with tm.assert_raises_regex(TypeError, "not callable"):
258-
df.groupby('key').rank(method=ties_method,
259-
ascending=ascending,
260-
na_option=True, pct=pct)
261257

262-
with tm.assert_raises_regex(TypeError, "not callable"):
258+
@pytest.mark.parametrize("na_option", [True, "bad", 1])
259+
@pytest.mark.parametrize("ties_method", [
260+
'average', 'min', 'max', 'first', 'dense'])
261+
@pytest.mark.parametrize("ascending", [True, False])
262+
@pytest.mark.parametrize("pct", [True, False])
263+
@pytest.mark.parametrize("vals", [
264+
['bar', 'bar', 'foo', 'bar', 'baz'],
265+
['bar', np.nan, 'foo', np.nan, 'baz'],
266+
[1, np.nan, 2, np.nan, 3]
267+
])
268+
def test_rank_naoption_raises(ties_method, ascending, na_option, pct, vals):
269+
df = DataFrame({'key': ['foo'] * 5, 'val': vals})
270+
msg = "na_option must be one of 'keep', 'top', or 'bottom'"
271+
272+
with tm.assert_raises_regex(ValueError, msg):
263273
df.groupby('key').rank(method=ties_method,
264274
ascending=ascending,
265275
na_option=na_option, pct=pct)

0 commit comments

Comments
 (0)