Skip to content

Commit 4a9c46b

Browse files
String dtype: allow string dtype for non-raw apply with numba engine (pandas-dev#59854)
* String dtype: allow string dtype for non-raw apply with numba engine * remove xfails * clean-up
1 parent dc4399c commit 4a9c46b

File tree

4 files changed

+2
-11
lines changed

4 files changed

+2
-11
lines changed

pandas/core/_numba/extensions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@
4949
@contextmanager
5050
def set_numba_data(index: Index):
5151
numba_data = index._data
52-
if numba_data.dtype == object:
52+
if numba_data.dtype in (object, "string"):
53+
numba_data = np.asarray(numba_data)
5354
if not lib.is_string_array(numba_data):
5455
raise ValueError(
5556
"The numba engine only supports using string or numeric column names"

pandas/core/apply.py

-5
Original file line numberDiff line numberDiff line change
@@ -1174,12 +1174,7 @@ def apply_with_numba(self) -> dict[int, Any]:
11741174
from pandas.core._numba.extensions import set_numba_data
11751175

11761176
index = self.obj.index
1177-
if index.dtype == "string":
1178-
index = index.astype(object)
1179-
11801177
columns = self.obj.columns
1181-
if columns.dtype == "string":
1182-
columns = columns.astype(object)
11831178

11841179
# Convert from numba dict to regular dict
11851180
# Our isinstance checks in the df constructor don't pass for numbas typed dict

pandas/tests/apply/test_frame_apply.py

-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def test_apply(float_frame, engine, request):
6565
assert result.index is float_frame.index
6666

6767

68-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
6968
@pytest.mark.parametrize("axis", [0, 1])
7069
@pytest.mark.parametrize("raw", [True, False])
7170
def test_apply_args(float_frame, axis, raw, engine, request):

pandas/tests/apply/test_numba.py

-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22
import pytest
33

4-
from pandas._config import using_string_dtype
5-
64
import pandas.util._test_decorators as td
75

86
import pandas as pd
@@ -20,7 +18,6 @@ def apply_axis(request):
2018
return request.param
2119

2220

23-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
2421
def test_numba_vs_python_noop(float_frame, apply_axis):
2522
func = lambda x: x
2623
result = float_frame.apply(func, engine="numba", axis=apply_axis)
@@ -43,7 +40,6 @@ def test_numba_vs_python_string_index():
4340
)
4441

4542

46-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
4743
def test_numba_vs_python_indexing():
4844
frame = DataFrame(
4945
{"a": [1, 2, 3], "b": [4, 5, 6], "c": [7.0, 8.0, 9.0]},

0 commit comments

Comments
 (0)