Skip to content

Commit 2e006e7

Browse files
TST/BUG (string dtype): Fix and adjust indexes string tests (#59544)
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent 62b2478 commit 2e006e7

File tree

5 files changed

+24
-30
lines changed

5 files changed

+24
-30
lines changed

pandas/core/construction.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,10 @@ def sanitize_array(
609609
dtype = StringDtype(na_value=np.nan)
610610
subarr = dtype.construct_array_type()._from_sequence(data, dtype=dtype)
611611

612-
if subarr is data and copy:
612+
if (
613+
subarr is data
614+
or (subarr.dtype == "str" and subarr.dtype.storage == "python") # type: ignore[union-attr]
615+
) and copy:
613616
subarr = subarr.copy()
614617

615618
else:

pandas/core/indexes/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,8 @@ def __new__(
506506

507507
elif is_ea_or_datetimelike_dtype(dtype):
508508
# non-EA dtype indexes have special casting logic, so we punt here
509-
pass
509+
if isinstance(data, (set, frozenset)):
510+
data = list(data)
510511

511512
elif is_ea_or_datetimelike_dtype(data_dtype):
512513
pass
@@ -6995,6 +6996,9 @@ def insert(self, loc: int, item) -> Index:
69956996
# We cannot keep the same dtype, so cast to the (often object)
69966997
# minimal shared dtype before doing the insert.
69976998
dtype = self._find_common_type_compat(item)
6999+
if dtype == self.dtype:
7000+
# EA's might run into recursion errors if loc is invalid
7001+
raise
69987002
return self.astype(dtype).insert(loc, item)
69997003

70007004
if arr.dtype != object or not isinstance(

pandas/tests/indexes/base_class/test_setops.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
import pandas as pd
97
from pandas import (
108
Index,
@@ -233,7 +231,6 @@ def test_tuple_union_bug(self, method, expected, sort):
233231
expected = Index(expected)
234232
tm.assert_index_equal(result, expected)
235233

236-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
237234
@pytest.mark.parametrize("first_list", [["b", "a"], []])
238235
@pytest.mark.parametrize("second_list", [["a", "b"], []])
239236
@pytest.mark.parametrize(
@@ -243,6 +240,7 @@ def test_tuple_union_bug(self, method, expected, sort):
243240
def test_union_name_preservation(
244241
self, first_list, second_list, first_name, second_name, expected_name, sort
245242
):
243+
expected_dtype = object if not first_list or not second_list else "str"
246244
first = Index(first_list, name=first_name)
247245
second = Index(second_list, name=second_name)
248246
union = first.union(second, sort=sort)
@@ -253,7 +251,7 @@ def test_union_name_preservation(
253251
expected = Index(sorted(vals), name=expected_name)
254252
tm.assert_index_equal(union, expected)
255253
else:
256-
expected = Index(vals, name=expected_name)
254+
expected = Index(vals, name=expected_name, dtype=expected_dtype)
257255
tm.assert_index_equal(union.sort_values(), expected.sort_values())
258256

259257
@pytest.mark.parametrize(

pandas/tests/indexes/test_base.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,6 @@ def test_constructor_casting(self, index):
7676
tm.assert_contains_all(arr, new_index)
7777
tm.assert_index_equal(index, new_index)
7878

79-
@pytest.mark.xfail(
80-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
81-
)
8279
def test_constructor_copy(self, using_infer_string):
8380
index = Index(list("abc"), name="name")
8481
arr = np.array(index)
@@ -346,11 +343,6 @@ def test_constructor_empty_special(self, empty, klass):
346343
def test_view_with_args(self, index):
347344
index.view("i8")
348345

349-
@pytest.mark.xfail(
350-
using_string_dtype() and not HAS_PYARROW,
351-
reason="TODO(infer_string)",
352-
strict=False,
353-
)
354346
@pytest.mark.parametrize(
355347
"index",
356348
[
@@ -367,7 +359,8 @@ def test_view_with_args_object_array_raises(self, index):
367359
msg = "When changing to a larger dtype"
368360
with pytest.raises(ValueError, match=msg):
369361
index.view("i8")
370-
elif index.dtype == "string":
362+
elif index.dtype == "str" and not index.dtype.storage == "python":
363+
# TODO(infer_string): Make the errors consistent
371364
with pytest.raises(NotImplementedError, match="i8"):
372365
index.view("i8")
373366
else:

pandas/tests/indexes/test_old_base.py

+11-15
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@
66
import numpy as np
77
import pytest
88

9-
from pandas._config import using_string_dtype
10-
119
from pandas._libs.tslibs import Timestamp
12-
from pandas.compat import HAS_PYARROW
1310

1411
from pandas.core.dtypes.common import (
1512
is_integer_dtype,
@@ -28,6 +25,7 @@
2825
PeriodIndex,
2926
RangeIndex,
3027
Series,
28+
StringDtype,
3129
TimedeltaIndex,
3230
isna,
3331
period_range,
@@ -233,7 +231,6 @@ def test_logical_compat(self, simple_index):
233231
with pytest.raises(TypeError, match=msg):
234232
idx.any()
235233

236-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
237234
def test_repr_roundtrip(self, simple_index):
238235
if isinstance(simple_index, IntervalIndex):
239236
pytest.skip(f"Not a valid repr for {type(simple_index).__name__}")
@@ -250,11 +247,6 @@ def test_repr_max_seq_item_setting(self, simple_index):
250247
repr(idx)
251248
assert "..." not in str(idx)
252249

253-
@pytest.mark.xfail(
254-
using_string_dtype() and not HAS_PYARROW,
255-
reason="TODO(infer_string)",
256-
strict=False,
257-
)
258250
@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning")
259251
def test_ensure_copied_data(self, index):
260252
# Check the "copy" argument of each Index.__new__ is honoured
@@ -302,7 +294,9 @@ def test_ensure_copied_data(self, index):
302294
tm.assert_numpy_array_equal(
303295
index._values._mask, result._values._mask, check_same="same"
304296
)
305-
elif index.dtype == "string[python]":
297+
elif (
298+
isinstance(index.dtype, StringDtype) and index.dtype.storage == "python"
299+
):
306300
assert np.shares_memory(index._values._ndarray, result._values._ndarray)
307301
tm.assert_numpy_array_equal(
308302
index._values._ndarray, result._values._ndarray, check_same="same"
@@ -432,11 +426,7 @@ def test_insert_base(self, index):
432426
result = trimmed.insert(0, index[0])
433427
assert index[0:4].equals(result)
434428

435-
@pytest.mark.skipif(
436-
using_string_dtype(),
437-
reason="completely different behavior, tested elsewher",
438-
)
439-
def test_insert_out_of_bounds(self, index):
429+
def test_insert_out_of_bounds(self, index, using_infer_string):
440430
# TypeError/IndexError matches what np.insert raises in these cases
441431

442432
if len(index) > 0:
@@ -448,6 +438,12 @@ def test_insert_out_of_bounds(self, index):
448438
msg = "index (0|0.5) is out of bounds for axis 0 with size 0"
449439
else:
450440
msg = "slice indices must be integers or None or have an __index__ method"
441+
442+
if using_infer_string and (
443+
index.dtype == "string" or index.dtype == "category" # noqa: PLR1714
444+
):
445+
msg = "loc must be an integer between"
446+
451447
with pytest.raises(err, match=msg):
452448
index.insert(0.5, "foo")
453449

0 commit comments

Comments
 (0)