Skip to content

Commit 16d2f59

Browse files
authored
BUG: nlargest/nsmallest can now consider nan values like sort_values(ascending=True).head(n) (#43060)
1 parent d8e1ba5 commit 16d2f59

File tree

5 files changed

+34
-5
lines changed

5 files changed

+34
-5
lines changed

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ Indexing
361361
- Bug in :meth:`Index.get_indexer_non_unique` when index contains multiple ``np.nan`` (:issue:`35392`)
362362
- Bug in :meth:`DataFrame.query` did not handle the degree sign in a backticked column name, such as \`Temp(°C)\`, used in an expression to query a dataframe (:issue:`42826`)
363363
- Bug in :meth:`DataFrame.drop` where the error message did not show missing labels with commas when raising ``KeyError`` (:issue:`42881`)
364-
-
364+
- Bug in :meth:`DataFrame.nlargest` and :meth:`Series.nlargest` where sorted result did not count indexes containing ``np.nan`` (:issue:`28984`)
365365

366366

367367
Missing

pandas/core/algorithms.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,8 @@ class SelectNSeries(SelectN):
12521252

12531253
def compute(self, method: str) -> Series:
12541254

1255+
from pandas.core.reshape.concat import concat
1256+
12551257
n = self.n
12561258
dtype = self.obj.dtype
12571259
if not self.is_valid_dtype_n_method(dtype):
@@ -1261,6 +1263,7 @@ def compute(self, method: str) -> Series:
12611263
return self.obj[[]]
12621264

12631265
dropped = self.obj.dropna()
1266+
nan_index = self.obj.drop(dropped.index)
12641267

12651268
if is_extension_array_dtype(dropped.dtype):
12661269
# GH#41816 bc we have dropped NAs above, MaskedArrays can use the
@@ -1277,7 +1280,7 @@ def compute(self, method: str) -> Series:
12771280
# slow method
12781281
if n >= len(self.obj):
12791282
ascending = method == "nsmallest"
1280-
return dropped.sort_values(ascending=ascending).head(n)
1283+
return self.obj.sort_values(ascending=ascending).head(n)
12811284

12821285
# fast method
12831286
new_dtype = dropped.dtype
@@ -1295,6 +1298,8 @@ def compute(self, method: str) -> Series:
12951298
if self.keep == "last":
12961299
arr = arr[::-1]
12971300

1301+
nbase = n
1302+
findex = len(self.obj)
12981303
narr = len(arr)
12991304
n = min(n, narr)
13001305

@@ -1306,12 +1311,13 @@ def compute(self, method: str) -> Series:
13061311

13071312
if self.keep != "all":
13081313
inds = inds[:n]
1314+
findex = nbase
13091315

13101316
if self.keep == "last":
13111317
# reverse indices
13121318
inds = narr - 1 - inds
13131319

1314-
return dropped.iloc[inds]
1320+
return concat([dropped.iloc[inds], nan_index]).iloc[:findex]
13151321

13161322

13171323
class SelectNFrame(SelectN):

pandas/tests/frame/methods/test_nlargest.py

+7
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,10 @@ def test_nlargest_multiindex_column_lookup(self):
209209
result = df.nlargest(3, ("x", "b"))
210210
expected = df.iloc[[3, 2, 1]]
211211
tm.assert_frame_equal(result, expected)
212+
213+
def test_nlargest_nan(self):
214+
# GH#43060
215+
df = pd.DataFrame([np.nan, np.nan, 0, 1, 2, 3])
216+
result = df.nlargest(5, 0)
217+
expected = df.sort_values(0, ascending=False).head(5)
218+
tm.assert_frame_equal(result, expected)

pandas/tests/groupby/test_apply.py

+12
Original file line numberDiff line numberDiff line change
@@ -1145,3 +1145,15 @@ def test_doctest_example2():
11451145
{"B": [1.0, 0.0], "C": [2.0, 0.0]}, index=Index(["a", "b"], name="A")
11461146
)
11471147
tm.assert_frame_equal(result, expected)
1148+
1149+
1150+
@pytest.mark.parametrize("dropna", [True, False])
1151+
def test_apply_na(dropna):
1152+
# GH#28984
1153+
df = DataFrame(
1154+
{"grp": [1, 1, 2, 2], "y": [1, 0, 2, 5], "z": [1, 2, np.nan, np.nan]}
1155+
)
1156+
dfgrp = df.groupby("grp", dropna=dropna)
1157+
result = dfgrp.apply(lambda grp_df: grp_df.nlargest(1, "z"))
1158+
expected = dfgrp.apply(lambda x: x.sort_values("z", ascending=False).head(1))
1159+
tm.assert_frame_equal(result, expected)

pandas/tests/series/methods/test_nlargest.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,12 @@ def test_nsmallest_nlargest(self, s_main_dtypes_split):
127127
def test_nlargest_misc(self):
128128

129129
ser = Series([3.0, np.nan, 1, 2, 5])
130-
tm.assert_series_equal(ser.nlargest(), ser.iloc[[4, 0, 3, 2]])
131-
tm.assert_series_equal(ser.nsmallest(), ser.iloc[[2, 3, 0, 4]])
130+
result = ser.nlargest()
131+
expected = ser.iloc[[4, 0, 3, 2, 1]]
132+
tm.assert_series_equal(result, expected)
133+
result = ser.nsmallest()
134+
expected = ser.iloc[[2, 3, 0, 4, 1]]
135+
tm.assert_series_equal(result, expected)
132136

133137
msg = 'keep must be either "first", "last"'
134138
with pytest.raises(ValueError, match=msg):

0 commit comments

Comments
 (0)