Skip to content

Commit 27ec887

Browse files
authored
BUG: numba raises for string columns or index (#56189)
1 parent 82c591d commit 27ec887

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

doc/source/whatsnew/v2.2.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,8 @@ Conversion
496496
Strings
497497
^^^^^^^
498498
- Bug in :func:`pandas.api.types.is_string_dtype` while checking object array with no elements is of the string dtype (:issue:`54661`)
499+
- Bug in :meth:`DataFrame.apply` failing when ``engine="numba"`` and columns or index have ``StringDtype`` (:issue:`56189`)
499500
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`)
500-
-
501501

502502
Interval
503503
^^^^^^^^

pandas/core/apply.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1172,11 +1172,17 @@ def apply_with_numba(self) -> dict[int, Any]:
11721172
)
11731173
from pandas.core._numba.extensions import set_numba_data
11741174

1175+
index = self.obj.index
1176+
if index.dtype == "string":
1177+
index = index.astype(object)
1178+
1179+
columns = self.obj.columns
1180+
if columns.dtype == "string":
1181+
columns = columns.astype(object)
1182+
11751183
# Convert from numba dict to regular dict
11761184
# Our isinstance checks in the df constructor don't pass for numbas typed dict
1177-
with set_numba_data(self.obj.index) as index, set_numba_data(
1178-
self.columns
1179-
) as columns:
1185+
with set_numba_data(index) as index, set_numba_data(columns) as columns:
11801186
res = dict(nb_func(self.values, columns, index))
11811187
return res
11821188

pandas/tests/apply/test_numba.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,22 @@ def test_numba_vs_python_noop(float_frame, apply_axis):
2424
tm.assert_frame_equal(result, expected)
2525

2626

27+
def test_numba_vs_python_string_index():
28+
# GH#56189
29+
pytest.importorskip("pyarrow")
30+
df = DataFrame(
31+
1,
32+
index=Index(["a", "b"], dtype="string[pyarrow_numpy]"),
33+
columns=Index(["x", "y"], dtype="string[pyarrow_numpy]"),
34+
)
35+
func = lambda x: x
36+
result = df.apply(func, engine="numba", axis=0)
37+
expected = df.apply(func, engine="python", axis=0)
38+
tm.assert_frame_equal(
39+
result, expected, check_column_type=False, check_index_type=False
40+
)
41+
42+
2743
def test_numba_vs_python_indexing():
2844
frame = DataFrame(
2945
{"a": [1, 2, 3], "b": [4, 5, 6], "c": [7.0, 8.0, 9.0]},
@@ -88,7 +104,8 @@ def test_numba_unsupported_dtypes(apply_axis):
88104
df["c"] = df["c"].astype("double[pyarrow]")
89105

90106
with pytest.raises(
91-
ValueError, match="Column b must have a numeric dtype. Found 'object' instead"
107+
ValueError,
108+
match="Column b must have a numeric dtype. Found 'object|string' instead",
92109
):
93110
df.apply(f, engine="numba", axis=apply_axis)
94111

0 commit comments

Comments
 (0)