Skip to content

Commit 98f7e4d

Browse files
String dtype: use ObjectEngine for indexing for now correctness over performance (#60329)
1 parent fd570f4 commit 98f7e4d

File tree

5 files changed

+124
-14
lines changed

5 files changed

+124
-14
lines changed

pandas/_libs/index.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ class MaskedUInt16Engine(MaskedIndexEngine): ...
7272
class MaskedUInt8Engine(MaskedIndexEngine): ...
7373
class MaskedBoolEngine(MaskedUInt8Engine): ...
7474

75+
class StringObjectEngine(ObjectEngine):
76+
def __init__(self, values: object, na_value) -> None: ...
77+
7578
class BaseMultiIndexCodesEngine:
7679
levels: list[np.ndarray]
7780
offsets: np.ndarray # np.ndarray[..., ndim=1]

pandas/_libs/index.pyx

+25
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,31 @@ cdef class StringEngine(IndexEngine):
557557
raise KeyError(val)
558558
return str(val)
559559

560+
cdef class StringObjectEngine(ObjectEngine):
561+
562+
cdef:
563+
object na_value
564+
bint uses_na
565+
566+
def __init__(self, ndarray values, na_value):
567+
super().__init__(values)
568+
self.na_value = na_value
569+
self.uses_na = na_value is C_NA
570+
571+
cdef bint _checknull(self, object val):
572+
if self.uses_na:
573+
return val is C_NA
574+
else:
575+
return util.is_nan(val)
576+
577+
cdef _check_type(self, object val):
578+
if isinstance(val, str):
579+
return val
580+
elif self._checknull(val):
581+
return self.na_value
582+
else:
583+
raise KeyError(val)
584+
560585

561586
cdef class DatetimeEngine(Int64Engine):
562587

pandas/core/indexes/base.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ def _engine(
876876
# ndarray[Any, Any]]" has no attribute "_ndarray" [union-attr]
877877
target_values = self._data._ndarray # type: ignore[union-attr]
878878
elif is_string_dtype(self.dtype) and not is_object_dtype(self.dtype):
879-
return libindex.StringEngine(target_values)
879+
return libindex.StringObjectEngine(target_values, self.dtype.na_value) # type: ignore[union-attr]
880880

881881
# error: Argument 1 to "ExtensionEngine" has incompatible type
882882
# "ndarray[Any, Any]"; expected "ExtensionArray"
@@ -5974,7 +5974,6 @@ def _should_fallback_to_positional(self) -> bool:
59745974
def get_indexer_non_unique(
59755975
self, target
59765976
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
5977-
target = ensure_index(target)
59785977
target = self._maybe_cast_listlike_indexer(target)
59795978

59805979
if not self._should_compare(target) and not self._should_partial_index(target):

pandas/tests/indexes/string/test_indexing.py

+93-11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,51 @@
66
import pandas._testing as tm
77

88

9+
def _isnan(val):
10+
try:
11+
return val is not pd.NA and np.isnan(val)
12+
except TypeError:
13+
return False
14+
15+
16+
class TestGetLoc:
17+
def test_get_loc(self, any_string_dtype):
18+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
19+
assert index.get_loc("b") == 1
20+
21+
def test_get_loc_raises(self, any_string_dtype):
22+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
23+
with pytest.raises(KeyError, match="d"):
24+
index.get_loc("d")
25+
26+
def test_get_loc_invalid_value(self, any_string_dtype):
27+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
28+
with pytest.raises(KeyError, match="1"):
29+
index.get_loc(1)
30+
31+
def test_get_loc_non_unique(self, any_string_dtype):
32+
index = Index(["a", "b", "a"], dtype=any_string_dtype)
33+
result = index.get_loc("a")
34+
expected = np.array([True, False, True])
35+
tm.assert_numpy_array_equal(result, expected)
36+
37+
def test_get_loc_non_missing(self, any_string_dtype, nulls_fixture):
38+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
39+
with pytest.raises(KeyError):
40+
index.get_loc(nulls_fixture)
41+
42+
def test_get_loc_missing(self, any_string_dtype, nulls_fixture):
43+
index = Index(["a", "b", nulls_fixture], dtype=any_string_dtype)
44+
if any_string_dtype == "string" and (
45+
(any_string_dtype.na_value is pd.NA and nulls_fixture is not pd.NA)
46+
or (_isnan(any_string_dtype.na_value) and not _isnan(nulls_fixture))
47+
):
48+
with pytest.raises(KeyError):
49+
index.get_loc(nulls_fixture)
50+
else:
51+
assert index.get_loc(nulls_fixture) == 2
52+
53+
954
class TestGetIndexer:
1055
@pytest.mark.parametrize(
1156
"method,expected",
@@ -41,23 +86,60 @@ def test_get_indexer_strings_raises(self, any_string_dtype):
4186
["a", "b", "c", "d"], method="pad", tolerance=[2, 2, 2, 2]
4287
)
4388

89+
@pytest.mark.parametrize("null", [None, np.nan, float("nan"), pd.NA])
90+
def test_get_indexer_missing(self, any_string_dtype, null, using_infer_string):
91+
# NaT and Decimal("NaN") from null_fixture are not supported for string dtype
92+
index = Index(["a", "b", null], dtype=any_string_dtype)
93+
result = index.get_indexer(["a", null, "c"])
94+
if using_infer_string:
95+
expected = np.array([0, 2, -1], dtype=np.intp)
96+
elif any_string_dtype == "string" and (
97+
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
98+
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
99+
):
100+
expected = np.array([0, -1, -1], dtype=np.intp)
101+
else:
102+
expected = np.array([0, 2, -1], dtype=np.intp)
44103

45-
class TestGetIndexerNonUnique:
46-
@pytest.mark.xfail(reason="TODO(infer_string)", strict=False)
47-
def test_get_indexer_non_unique_nas(self, any_string_dtype, nulls_fixture):
48-
index = Index(["a", "b", None], dtype=any_string_dtype)
49-
indexer, missing = index.get_indexer_non_unique([nulls_fixture])
104+
tm.assert_numpy_array_equal(result, expected)
50105

51-
expected_indexer = np.array([2], dtype=np.intp)
52-
expected_missing = np.array([], dtype=np.intp)
106+
107+
class TestGetIndexerNonUnique:
108+
@pytest.mark.parametrize("null", [None, np.nan, float("nan"), pd.NA])
109+
def test_get_indexer_non_unique_nas(
110+
self, any_string_dtype, null, using_infer_string
111+
):
112+
index = Index(["a", "b", null], dtype=any_string_dtype)
113+
indexer, missing = index.get_indexer_non_unique(["a", null])
114+
115+
if using_infer_string:
116+
expected_indexer = np.array([0, 2], dtype=np.intp)
117+
expected_missing = np.array([], dtype=np.intp)
118+
elif any_string_dtype == "string" and (
119+
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
120+
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
121+
):
122+
expected_indexer = np.array([0, -1], dtype=np.intp)
123+
expected_missing = np.array([1], dtype=np.intp)
124+
else:
125+
expected_indexer = np.array([0, 2], dtype=np.intp)
126+
expected_missing = np.array([], dtype=np.intp)
53127
tm.assert_numpy_array_equal(indexer, expected_indexer)
54128
tm.assert_numpy_array_equal(missing, expected_missing)
55129

56130
# actually non-unique
57-
index = Index(["a", None, "b", None], dtype=any_string_dtype)
58-
indexer, missing = index.get_indexer_non_unique([nulls_fixture])
59-
60-
expected_indexer = np.array([1, 3], dtype=np.intp)
131+
index = Index(["a", null, "b", null], dtype=any_string_dtype)
132+
indexer, missing = index.get_indexer_non_unique(["a", null])
133+
134+
if using_infer_string:
135+
expected_indexer = np.array([0, 1, 3], dtype=np.intp)
136+
elif any_string_dtype == "string" and (
137+
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
138+
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
139+
):
140+
pass
141+
else:
142+
expected_indexer = np.array([0, 1, 3], dtype=np.intp)
61143
tm.assert_numpy_array_equal(indexer, expected_indexer)
62144
tm.assert_numpy_array_equal(missing, expected_missing)
63145

pandas/tests/io/parser/common/test_common_basic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from pandas._config import using_string_dtype
1717

18+
from pandas.compat import HAS_PYARROW
1819
from pandas.errors import (
1920
EmptyDataError,
2021
ParserError,
@@ -766,7 +767,7 @@ def test_dict_keys_as_names(all_parsers):
766767
tm.assert_frame_equal(result, expected)
767768

768769

769-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
770+
@pytest.mark.xfail(using_string_dtype() and HAS_PYARROW, reason="TODO(infer_string)")
770771
@xfail_pyarrow # UnicodeDecodeError: 'utf-8' codec can't decode byte 0xed in position 0
771772
def test_encoding_surrogatepass(all_parsers):
772773
# GH39017

0 commit comments

Comments
 (0)