Skip to content

Commit 83f9cbf

Browse files
Raise error for string types in nsmallest and nlargest (#13946)
closes #13945 This PR contains changes that raises an error message exactly matching pandas for `nsmallest` and `nlargest`. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Bradley Dice (https://github.com/bdice) URL: #13946
1 parent 6ed42d7 commit 83f9cbf

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

python/cudf/cudf/core/indexed_frame.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2465,6 +2465,20 @@ def _n_largest_or_smallest(self, largest, n, columns, keep):
24652465
if isinstance(columns, str):
24662466
columns = [columns]
24672467

2468+
method = "nlargest" if largest else "nsmallest"
2469+
for col in columns:
2470+
if isinstance(self._data[col], cudf.core.column.StringColumn):
2471+
if isinstance(self, cudf.DataFrame):
2472+
error_msg = (
2473+
f"Column '{col}' has dtype {self._data[col].dtype}, "
2474+
f"cannot use method '{method}' with this dtype"
2475+
)
2476+
else:
2477+
error_msg = (
2478+
f"Cannot use method '{method}' with "
2479+
f"dtype {self._data[col].dtype}"
2480+
)
2481+
raise TypeError(error_msg)
24682482
if len(self) == 0:
24692483
return self
24702484

python/cudf/cudf/tests/test_dataframe.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10316,3 +10316,16 @@ def test_dataframe_reindex_with_index_names(index_data, name):
1031610316
expected = pdf.reindex(index_data)
1031710317

1031810318
assert_eq(actual, expected)
10319+
10320+
10321+
@pytest.mark.parametrize("attr", ["nlargest", "nsmallest"])
10322+
def test_dataframe_nlargest_nsmallest_str_error(attr):
10323+
gdf = cudf.DataFrame({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "d"]})
10324+
pdf = gdf.to_pandas()
10325+
10326+
assert_exceptions_equal(
10327+
getattr(gdf, attr),
10328+
getattr(pdf, attr),
10329+
([], {"n": 1, "columns": ["a", "b"]}),
10330+
([], {"n": 1, "columns": ["a", "b"]}),
10331+
)

python/cudf/cudf/tests/test_series.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,3 +2244,13 @@ def test_series_typecast_to_object():
22442244
assert new_series[0] == "1970-01-01 00:00:00.000000001"
22452245
new_series = actual.astype(np.dtype("object"))
22462246
assert new_series[0] == "1970-01-01 00:00:00.000000001"
2247+
2248+
2249+
@pytest.mark.parametrize("attr", ["nlargest", "nsmallest"])
2250+
def test_series_nlargest_nsmallest_str_error(attr):
2251+
gs = cudf.Series(["a", "b", "c", "d", "e"])
2252+
ps = gs.to_pandas()
2253+
2254+
assert_exceptions_equal(
2255+
getattr(gs, attr), getattr(ps, attr), ([], {"n": 1}), ([], {"n": 1})
2256+
)

0 commit comments

Comments
 (0)