Skip to content

Commit f841ef8

Browse files
committed
BUG: Fix problems in group rank when both nans and infinity are present (#20561)
1 parent 2794474 commit f841ef8

File tree

2 files changed: +54 -4 lines changed

pandas/_libs/groupby_helper.pxi.in

+5 -4
Original file line number | Diff line number | Diff line change
@@ -508,7 +508,8 @@ def group_rank_{{name}}(ndarray[float64_t, ndim=2] out,
508 508
509 509   # if keep_na, check for missing values and assign back
510 510   # to the result where appropriate
511     - if keep_na and masked_vals[_as[i]] == nan_fill_val:
    511 +
    512 + if keep_na and mask[_as[i]]:
512 513   grp_na_count += 1
513 514   out[_as[i], 0] = nan
514 515   else:
@@ -548,9 +549,9 @@ def group_rank_{{name}}(ndarray[float64_t, ndim=2] out,
548 549   # reset the dups and sum_ranks, knowing that a new value is coming
549 550   # up. the conditional also needs to handle nan equality and the
550 551   # end of iteration
551     - if (i == N - 1 or (
552     - (masked_vals[_as[i]] != masked_vals[_as[i+1]]) and not
553     - (mask[_as[i]] and mask[_as[i+1]]))):
    552 + if (i == N - 1 or
    553 + (masked_vals[_as[i]] != masked_vals[_as[i+1]]) or
    554 + (mask[_as[i]] ^ mask[_as[i+1]])):
554 555   dups = sum_ranks = 0
555 556   val_start = i
556 557   grp_vals_seen += 1

pandas/tests/groupby/test_groupby.py

+49
Original file line number | Diff line number | Diff line change
@@ -1965,6 +1965,55 @@ def test_rank_args(self, grps, vals, ties_method, ascending, pct, exp):
1965 1965   exp_df = DataFrame(exp * len(grps), columns=['val'])
1966 1966   assert_frame_equal(result, exp_df)
1967 1967
     1968 + @pytest.mark.parametrize("grps", [
     1969 + ['qux'], ['qux', 'quux']])
     1970 + @pytest.mark.parametrize("vals", [
     1971 + [-np.inf, -np.inf, np.nan, 1., np.nan, np.inf, np.inf],
     1972 + ])
     1973 + @pytest.mark.parametrize("ties_method,ascending,na_option,exp", [
     1974 + ('average', True, 'keep', [1.5, 1.5, np.nan, 3, np.nan, 4.5, 4.5]),
     1975 + ('average', True, 'top', [3.5, 3.5, 1.5, 5., 1.5, 6.5, 6.5]),
     1976 + ('average', True, 'bottom', [1.5, 1.5, 6.5, 3., 6.5, 4.5, 4.5]),
     1977 + ('average', False, 'keep', [1.5, 1.5, np.nan, 3, np.nan, 4.5, 4.5
     1978 + ][::-1]),
     1979 + ('average', False, 'top', [3.5, 3.5, 1.5, 5., 1.5, 6.5, 6.5][::-1]),
     1980 + ('average', False, 'bottom', [1.5, 1.5, 6.5, 3., 6.5, 4.5, 4.5][::-1]),
     1981 + ('min', True, 'keep', [1., 1., np.nan, 3., np.nan, 4., 4.]),
     1982 + ('min', True, 'top', [3., 3., 1., 5., 1., 6., 6.]),
     1983 + ('min', True, 'bottom', [1., 1., 6., 3., 6., 4., 4.]),
     1984 + ('min', False, 'keep', [1., 1., np.nan, 3., np.nan, 4., 4.][::-1]),
     1985 + ('min', False, 'top', [3., 3., 1., 5., 1., 6., 6.][::-1]),
     1986 + ('min', False, 'bottom', [1., 1., 6., 3., 6., 4., 4.][::-1]),
     1987 + ('max', True, 'keep', [2., 2., np.nan, 3., np.nan, 5., 5.]),
     1988 + ('max', True, 'top', [4., 4., 2., 5., 2., 7., 7.]),
     1989 + ('max', True, 'bottom', [2., 2., 7., 3., 7., 5., 5.]),
     1990 + ('max', False, 'keep', [2., 2., np.nan, 3., np.nan, 5., 5.][::-1]),
     1991 + ('max', False, 'top', [4., 4., 2., 5., 2., 7., 7.][::-1]),
     1992 + ('max', False, 'bottom', [2., 2., 7., 3., 7., 5., 5.][::-1]),
     1993 + ('first', True, 'keep', [1., 2., np.nan, 3., np.nan, 4., 5.]),
     1994 + ('first', True, 'top', [3., 4., 1., 5., 2., 6., 7.]),
     1995 + ('first', True, 'bottom', [1., 2., 6., 3., 7., 4., 5.]),
     1996 + ('first', False, 'keep', [4., 5., np.nan, 3., np.nan, 1., 2.]),
     1997 + ('first', False, 'top', [6., 7., 1., 5., 2., 3., 4.]),
     1998 + ('first', False, 'bottom', [4., 5., 6., 3., 7., 1., 2.]),
     1999 + ('dense', True, 'keep', [1., 1., np.nan, 2., np.nan, 3., 3.]),
     2000 + ('dense', True, 'top', [2., 2., 1., 3., 1., 4., 4.]),
     2001 + ('dense', True, 'bottom', [1., 1., 4., 2., 4., 3., 3.]),
     2002 + ('dense', False, 'keep', [3., 3., np.nan, 2., np.nan, 1., 1.]),
     2003 + ('dense', False, 'top', [4., 4., 1., 3., 1., 2., 2.]),
     2004 + ('dense', False, 'bottom', [3., 3., 4., 2., 4., 1., 1.]),
     2005 + ])
     2006 + def test_infs_n_nans(self, grps, vals, ties_method, ascending, na_option,
     2007 + exp):
     2008 + key = np.repeat(grps, len(vals))
     2009 + vals = vals * len(grps)
     2010 + df = DataFrame({'key': key, 'val': vals})
     2011 + result = df.groupby('key').rank(method=ties_method,
     2012 + ascending=ascending,
     2013 + na_option=na_option)
     2014 + exp_df = DataFrame(exp * len(grps), columns=['val'])
     2015 + assert_frame_equal(result, exp_df)
     2016 +
1968 2017   @pytest.mark.parametrize("grps", [
1969 2018   ['qux'], ['qux', 'quux']])
1970 2019   @pytest.mark.parametrize("vals", [

0 commit comments

Comments
 (0)