Skip to content

Commit b9f1bc6

Browse files
Backport PR #60195 on branch 2.3.x (BUG (string dtype): fix where() for string dtype with python storage) (#60202)
Backport PR #60195: BUG (string dtype): fix where() for string dtype with python storage Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent 70e8a3b commit b9f1bc6

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

pandas/core/arrays/string_.py

+6
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,12 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
757757
# base class implementation that uses __setitem__
758758
ExtensionArray._putmask(self, mask, value)
759759

760+
def _where(self, mask: npt.NDArray[np.bool_], value) -> Self:
761+
# the super() method NDArrayBackedExtensionArray._where uses
762+
# np.putmask which doesn't properly handle None/pd.NA, so using the
763+
# base class implementation that uses __setitem__
764+
return ExtensionArray._where(self, mask, value)
765+
760766
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
761767
if isinstance(values, BaseStringArray) or (
762768
isinstance(values, ExtensionArray) and is_string_dtype(values.dtype)

pandas/tests/frame/indexing/test_where.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from pandas._config import using_string_dtype
88

9-
from pandas.compat import HAS_PYARROW
10-
119
from pandas.core.dtypes.common import is_scalar
1210

1311
import pandas as pd
@@ -985,9 +983,6 @@ def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype):
985983
obj.mask(mask, null)
986984

987985

988-
@pytest.mark.xfail(
989-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
990-
)
991986
@given(data=OPTIONAL_ONE_OF_ALL)
992987
def test_where_inplace_casting(data):
993988
# GH 22051
@@ -1084,19 +1079,18 @@ def test_where_producing_ea_cond_for_np_dtype():
10841079
tm.assert_frame_equal(result, expected)
10851080

10861081

1087-
@pytest.mark.xfail(
1088-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)", strict=False
1089-
)
10901082
@pytest.mark.parametrize(
10911083
"replacement", [0.001, True, "snake", None, datetime(2022, 5, 4)]
10921084
)
1093-
def test_where_int_overflow(replacement, using_infer_string, request):
1085+
def test_where_int_overflow(replacement, using_infer_string):
10941086
# GH 31687
10951087
df = DataFrame([[1.0, 2e25, "nine"], [np.nan, 0.1, None]])
10961088
if using_infer_string and replacement not in (None, "snake"):
1097-
request.node.add_marker(
1098-
pytest.mark.xfail(reason="Can't set non-string into string column")
1099-
)
1089+
with pytest.raises(
1090+
TypeError, match="Cannot set non-string value|Scalar must be NA or str"
1091+
):
1092+
df.where(pd.notnull(df), replacement)
1093+
return
11001094
result = df.where(pd.notnull(df), replacement)
11011095
expected = DataFrame([[1.0, 2e25, "nine"], [replacement, 0.1, replacement]])
11021096

0 commit comments

Comments
 (0)