Skip to content

Commit 2b37c98

Browse files
[backport 2.3.x] String dtype: use ObjectEngine for indexing for now correctness over performance (#60329) (#60453)
String dtype: use ObjectEngine for indexing for now correctness over performance (#60329) (cherry picked from commit 98f7e4d)
1 parent 3bcbf0c commit 2b37c98

File tree

4 files changed

+124
-12
lines changed

4 files changed

+124
-12
lines changed

pandas/_libs/index.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ class MaskedUInt16Engine(MaskedIndexEngine): ...
6868
class MaskedUInt8Engine(MaskedIndexEngine): ...
6969
class MaskedBoolEngine(MaskedUInt8Engine): ...
7070

71+
class StringObjectEngine(ObjectEngine):
72+
def __init__(self, values: object, na_value) -> None: ...
73+
7174
class BaseMultiIndexCodesEngine:
7275
levels: list[np.ndarray]
7376
offsets: np.ndarray # ndarray[uint64_t, ndim=1]

pandas/_libs/index.pyx

+26
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,32 @@ cdef class ObjectEngine(IndexEngine):
532532
return loc
533533

534534

535+
cdef class StringObjectEngine(ObjectEngine):
536+
537+
cdef:
538+
object na_value
539+
bint uses_na
540+
541+
def __init__(self, ndarray values, na_value):
542+
super().__init__(values)
543+
self.na_value = na_value
544+
self.uses_na = na_value is C_NA
545+
546+
cdef bint _checknull(self, object val):
547+
if self.uses_na:
548+
return val is C_NA
549+
else:
550+
return util.is_nan(val)
551+
552+
cdef _check_type(self, object val):
553+
if isinstance(val, str):
554+
return val
555+
elif self._checknull(val):
556+
return self.na_value
557+
else:
558+
raise KeyError(val)
559+
560+
535561
cdef class DatetimeEngine(Int64Engine):
536562

537563
cdef:

pandas/core/indexes/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,8 @@ def _engine(
884884
# error: Item "ExtensionArray" of "Union[ExtensionArray,
885885
# ndarray[Any, Any]]" has no attribute "_ndarray" [union-attr]
886886
target_values = self._data._ndarray # type: ignore[union-attr]
887+
elif is_string_dtype(self.dtype) and not is_object_dtype(self.dtype):
888+
return libindex.StringObjectEngine(target_values, self.dtype.na_value) # type: ignore[union-attr]
887889

888890
# error: Argument 1 to "ExtensionEngine" has incompatible type
889891
# "ndarray[Any, Any]"; expected "ExtensionArray"
@@ -6133,7 +6135,6 @@ def _should_fallback_to_positional(self) -> bool:
61336135
def get_indexer_non_unique(
61346136
self, target
61356137
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
6136-
target = ensure_index(target)
61376138
target = self._maybe_cast_listlike_indexer(target)
61386139

61396140
if not self._should_compare(target) and not self._should_partial_index(target):

pandas/tests/indexes/string/test_indexing.py

+93-11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,51 @@
66
import pandas._testing as tm
77

88

9+
def _isnan(val):
10+
try:
11+
return val is not pd.NA and np.isnan(val)
12+
except TypeError:
13+
return False
14+
15+
16+
class TestGetLoc:
17+
def test_get_loc(self, any_string_dtype):
18+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
19+
assert index.get_loc("b") == 1
20+
21+
def test_get_loc_raises(self, any_string_dtype):
22+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
23+
with pytest.raises(KeyError, match="d"):
24+
index.get_loc("d")
25+
26+
def test_get_loc_invalid_value(self, any_string_dtype):
27+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
28+
with pytest.raises(KeyError, match="1"):
29+
index.get_loc(1)
30+
31+
def test_get_loc_non_unique(self, any_string_dtype):
32+
index = Index(["a", "b", "a"], dtype=any_string_dtype)
33+
result = index.get_loc("a")
34+
expected = np.array([True, False, True])
35+
tm.assert_numpy_array_equal(result, expected)
36+
37+
def test_get_loc_non_missing(self, any_string_dtype, nulls_fixture):
38+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
39+
with pytest.raises(KeyError):
40+
index.get_loc(nulls_fixture)
41+
42+
def test_get_loc_missing(self, any_string_dtype, nulls_fixture):
43+
index = Index(["a", "b", nulls_fixture], dtype=any_string_dtype)
44+
if any_string_dtype == "string" and (
45+
(any_string_dtype.na_value is pd.NA and nulls_fixture is not pd.NA)
46+
or (_isnan(any_string_dtype.na_value) and not _isnan(nulls_fixture))
47+
):
48+
with pytest.raises(KeyError):
49+
index.get_loc(nulls_fixture)
50+
else:
51+
assert index.get_loc(nulls_fixture) == 2
52+
53+
954
class TestGetIndexer:
1055
@pytest.mark.parametrize(
1156
"method,expected",
@@ -41,23 +86,60 @@ def test_get_indexer_strings_raises(self, any_string_dtype):
4186
["a", "b", "c", "d"], method="pad", tolerance=[2, 2, 2, 2]
4287
)
4388

89+
@pytest.mark.parametrize("null", [None, np.nan, float("nan"), pd.NA])
90+
def test_get_indexer_missing(self, any_string_dtype, null, using_infer_string):
91+
# NaT and Decimal("NaN") from null_fixture are not supported for string dtype
92+
index = Index(["a", "b", null], dtype=any_string_dtype)
93+
result = index.get_indexer(["a", null, "c"])
94+
if using_infer_string:
95+
expected = np.array([0, 2, -1], dtype=np.intp)
96+
elif any_string_dtype == "string" and (
97+
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
98+
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
99+
):
100+
expected = np.array([0, -1, -1], dtype=np.intp)
101+
else:
102+
expected = np.array([0, 2, -1], dtype=np.intp)
44103

45-
class TestGetIndexerNonUnique:
46-
@pytest.mark.xfail(reason="TODO(infer_string)", strict=False)
47-
def test_get_indexer_non_unique_nas(self, any_string_dtype, nulls_fixture):
48-
index = Index(["a", "b", None], dtype=any_string_dtype)
49-
indexer, missing = index.get_indexer_non_unique([nulls_fixture])
104+
tm.assert_numpy_array_equal(result, expected)
50105

51-
expected_indexer = np.array([2], dtype=np.intp)
52-
expected_missing = np.array([], dtype=np.intp)
106+
107+
class TestGetIndexerNonUnique:
108+
@pytest.mark.parametrize("null", [None, np.nan, float("nan"), pd.NA])
109+
def test_get_indexer_non_unique_nas(
110+
self, any_string_dtype, null, using_infer_string
111+
):
112+
index = Index(["a", "b", null], dtype=any_string_dtype)
113+
indexer, missing = index.get_indexer_non_unique(["a", null])
114+
115+
if using_infer_string:
116+
expected_indexer = np.array([0, 2], dtype=np.intp)
117+
expected_missing = np.array([], dtype=np.intp)
118+
elif any_string_dtype == "string" and (
119+
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
120+
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
121+
):
122+
expected_indexer = np.array([0, -1], dtype=np.intp)
123+
expected_missing = np.array([1], dtype=np.intp)
124+
else:
125+
expected_indexer = np.array([0, 2], dtype=np.intp)
126+
expected_missing = np.array([], dtype=np.intp)
53127
tm.assert_numpy_array_equal(indexer, expected_indexer)
54128
tm.assert_numpy_array_equal(missing, expected_missing)
55129

56130
# actually non-unique
57-
index = Index(["a", None, "b", None], dtype=any_string_dtype)
58-
indexer, missing = index.get_indexer_non_unique([nulls_fixture])
59-
60-
expected_indexer = np.array([1, 3], dtype=np.intp)
131+
index = Index(["a", null, "b", null], dtype=any_string_dtype)
132+
indexer, missing = index.get_indexer_non_unique(["a", null])
133+
134+
if using_infer_string:
135+
expected_indexer = np.array([0, 1, 3], dtype=np.intp)
136+
elif any_string_dtype == "string" and (
137+
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
138+
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
139+
):
140+
pass
141+
else:
142+
expected_indexer = np.array([0, 1, 3], dtype=np.intp)
61143
tm.assert_numpy_array_equal(indexer, expected_indexer)
62144
tm.assert_numpy_array_equal(missing, expected_missing)
63145

0 commit comments

Comments
 (0)