Skip to content

Commit 633ff18

Browse files
authored
BUG: Series.where with incompatible NA value (#44697)
1 parent 878a022 commit 633ff18

File tree

7 files changed

+85
-15
lines changed

7 files changed

+85
-15
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,7 @@ ExtensionArray
803803
- Avoid raising ``PerformanceWarning`` about fragmented DataFrame when using many columns with an extension dtype (:issue:`44098`)
804804
- Bug in :class:`IntegerArray` and :class:`FloatingArray` construction incorrectly coercing mismatched NA values (e.g. ``np.timedelta64("NaT")``) to numeric NA (:issue:`44514`)
805805
- Bug in :meth:`BooleanArray.__eq__` and :meth:`BooleanArray.__ne__` raising ``TypeError`` on comparison with an incompatible type (like a string). This caused :meth:`DataFrame.replace` to sometimes raise a ``TypeError`` if a nullable boolean column was included (:issue:`44499`)
806+
- Bug in :meth:`Series.where` with ``ExtensionDtype`` when ``other`` is an NA scalar incompatible with the series dtype (e.g. ``NaT`` with a numeric dtype) incorrectly casting to a compatible NA value (:issue:`44697`)
806807
-
807808

808809
Styler

pandas/core/frame.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10848,7 +10848,7 @@ def interpolate(
1084810848
def where(
1084910849
self,
1085010850
cond,
10851-
other=np.nan,
10851+
other=lib.no_default,
1085210852
inplace=False,
1085310853
axis=None,
1085410854
level=None,

pandas/core/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8937,7 +8937,7 @@ def _align_series(
89378937
def _where(
89388938
self,
89398939
cond,
8940-
other=np.nan,
8940+
other=lib.no_default,
89418941
inplace=False,
89428942
axis=None,
89438943
level=None,

pandas/core/internals/blocks.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,9 @@ def putmask(self, mask, new) -> list[Block]:
962962
mask, noop = validate_putmask(values.T, mask)
963963
assert not isinstance(new, (ABCIndex, ABCSeries, ABCDataFrame))
964964

965+
if new is lib.no_default:
966+
new = self.fill_value
967+
965968
# if we are passed a scalar None, convert it here
966969
if not self.is_object and is_valid_na_for_dtype(new, self.dtype):
967970
new = self.fill_value
@@ -1173,6 +1176,9 @@ def where(self, other, cond) -> list[Block]:
11731176

11741177
icond, noop = validate_putmask(values, ~cond)
11751178

1179+
if other is lib.no_default:
1180+
other = self.fill_value
1181+
11761182
if is_valid_na_for_dtype(other, self.dtype) and self.dtype != _dtype_obj:
11771183
other = self.fill_value
11781184

@@ -1640,13 +1646,8 @@ def where(self, other, cond) -> list[Block]:
16401646
other = self._maybe_squeeze_arg(other)
16411647
cond = self._maybe_squeeze_arg(cond)
16421648

1643-
if lib.is_scalar(other) and isna(other):
1644-
# The default `other` for Series / Frame is np.nan
1645-
# we want to replace that with the correct NA value
1646-
# for the type
1647-
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]" has no
1648-
# attribute "na_value"
1649-
other = self.dtype.na_value # type: ignore[union-attr]
1649+
if other is lib.no_default:
1650+
other = self.fill_value
16501651

16511652
icond, noop = validate_putmask(self.values, ~cond)
16521653
if noop:
@@ -1741,6 +1742,8 @@ def where(self, other, cond) -> list[Block]:
17411742
arr = self.values
17421743

17431744
cond = extract_bool_array(cond)
1745+
if other is lib.no_default:
1746+
other = self.fill_value
17441747

17451748
try:
17461749
res_values = arr.T._where(cond, other).T

pandas/core/series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5510,7 +5510,7 @@ def interpolate(
55105510
def where(
55115511
self,
55125512
cond,
5513-
other=np.nan,
5513+
other=lib.no_default,
55145514
inplace=False,
55155515
axis=None,
55165516
level=None,

pandas/tests/arrays/integer/test_construction.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,13 @@ def test_to_integer_array_none_is_nan(a, b):
133133
)
134134
def test_to_integer_array_error(values):
135135
# error in converting existing arrays to IntegerArrays
136-
msg = (
137-
r"(:?.* cannot be converted to an IntegerDtype)"
138-
r"|(invalid literal for int\(\) with base 10: .*)"
139-
r"|(:?values must be a 1D list-like)"
140-
r"|(Cannot pass scalar)"
136+
msg = "|".join(
137+
[
138+
r"cannot be converted to an IntegerDtype",
139+
r"invalid literal for int\(\) with base 10:",
140+
r"values must be a 1D list-like",
141+
r"Cannot pass scalar",
142+
]
141143
)
142144
with pytest.raises((ValueError, TypeError), match=msg):
143145
pd.array(values, dtype="Int64")

pandas/tests/frame/indexing/test_where.py

+64
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import numpy as np
88
import pytest
99

10+
from pandas.compat import np_version_under1p19
11+
1012
from pandas.core.dtypes.common import is_scalar
1113

1214
import pandas as pd
@@ -810,6 +812,68 @@ def test_where_columns_casting():
810812
tm.assert_frame_equal(expected, result)
811813

812814

815+
@pytest.mark.parametrize("as_cat", [True, False])
816+
def test_where_period_invalid_na(frame_or_series, as_cat, request):
817+
# GH#44697
818+
idx = pd.period_range("2016-01-01", periods=3, freq="D")
819+
if as_cat:
820+
idx = idx.astype("category")
821+
obj = frame_or_series(idx)
822+
823+
# NA value that we should *not* cast to Period dtype
824+
tdnat = pd.NaT.to_numpy("m8[ns]")
825+
826+
mask = np.array([True, True, False], ndmin=obj.ndim).T
827+
828+
if as_cat:
829+
msg = (
830+
r"Cannot setitem on a Categorical with a new category \(NaT\), "
831+
"set the categories first"
832+
)
833+
if np_version_under1p19:
834+
mark = pytest.mark.xfail(
835+
reason="When evaluating the f-string to generate the exception "
836+
"message, numpy somehow ends up trying to cast None to int, so "
837+
"ends up raising TypeError but with an unrelated message."
838+
)
839+
request.node.add_marker(mark)
840+
else:
841+
msg = "value should be a 'Period'"
842+
843+
with pytest.raises(TypeError, match=msg):
844+
obj.where(mask, tdnat)
845+
846+
with pytest.raises(TypeError, match=msg):
847+
obj.mask(mask, tdnat)
848+
849+
850+
def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype):
851+
# GH#44697
852+
arr = pd.array([1, 2, 3], dtype=any_numeric_ea_dtype)
853+
obj = frame_or_series(arr)
854+
855+
mask = np.array([True, True, False], ndmin=obj.ndim).T
856+
857+
msg = "|".join(
858+
[
859+
r"datetime64\[.{1,2}\] cannot be converted to an? (Integer|Floating)Dtype",
860+
r"timedelta64\[.{1,2}\] cannot be converted to an? (Integer|Floating)Dtype",
861+
r"int\(\) argument must be a string, a bytes-like object or a number, "
862+
"not 'NaTType'",
863+
"object cannot be converted to a FloatingDtype",
864+
"'values' contains non-numeric NA",
865+
]
866+
)
867+
868+
for null in tm.NP_NAT_OBJECTS + [pd.NaT]:
869+
# NaT is an NA value that we should *not* cast to pd.NA for a nullable numeric dtype
870+
with pytest.raises(TypeError, match=msg):
871+
obj.where(mask, null)
872+
873+
with pytest.raises(TypeError, match=msg):
874+
obj.mask(mask, null)
875+
876+
813877
@given(
814878
data=st.one_of(
815879
OPTIONAL_DICTS, OPTIONAL_FLOATS, OPTIONAL_INTS, OPTIONAL_LISTS, OPTIONAL_TEXT

0 commit comments

Comments
 (0)