Skip to content

Commit be8d1ec

Browse files
authored
BUG: Series[Period][mask] = 'foo' raising inconsistent with non-mask indexing (#45768)
1 parent 06dac44 commit be8d1ec

File tree

6 files changed

+89
-29
lines changed

6 files changed

+89
-29
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ Indexing
266266
- Bug in :meth:`DataFrame.iloc` where indexing a single row on a :class:`DataFrame` with a single ExtensionDtype column gave a copy instead of a view on the underlying data (:issue:`45241`)
267267
- Bug in setting a NA value (``None`` or ``np.nan``) into a :class:`Series` with int-based :class:`IntervalDtype` incorrectly casting to object dtype instead of a float-based :class:`IntervalDtype` (:issue:`45568`)
268268
- Bug in :meth:`Series.__setitem__` with a non-integer :class:`Index` when using an integer key to set a value that cannot be set inplace where a ``ValueError`` was raised insead of casting to a common dtype (:issue:`45070`)
269+
- Bug in :meth:`Series.__setitem__` when setting incompatible values into a ``PeriodDtype`` or ``IntervalDtype`` :class:`Series` raising when indexing with a boolean mask but coercing when indexing with otherwise-equivalent indexers; these now consistently coerce, along with :meth:`Series.mask` and :meth:`Series.where` (:issue:`45768`)
269270
- Bug in :meth:`Series.loc.__setitem__` and :meth:`Series.loc.__getitem__` not raising when using multiple keys without using a :class:`MultiIndex` (:issue:`13831`)
270271
- Bug when setting a value too large for a :class:`Series` dtype failing to coerce to a common type (:issue:`26049`, :issue:`32878`)
271272
- Bug in :meth:`loc.__setitem__` treating ``range`` keys as positional instead of label-based (:issue:`45479`)

pandas/core/internals/blocks.py

+13-22
Original file line numberDiff line numberDiff line change
@@ -1376,6 +1376,8 @@ def where(self, other, cond) -> list[Block]:
13761376

13771377
cond = extract_bool_array(cond)
13781378

1379+
orig_other = other
1380+
orig_cond = cond
13791381
other = self._maybe_squeeze_arg(other)
13801382
cond = self._maybe_squeeze_arg(cond)
13811383

@@ -1395,21 +1397,15 @@ def where(self, other, cond) -> list[Block]:
13951397

13961398
if is_interval_dtype(self.dtype):
13971399
# TestSetitemFloatIntervalWithIntIntervalValues
1398-
blk = self.coerce_to_target_dtype(other)
1399-
if blk.dtype == _dtype_obj:
1400-
# For now at least only support casting e.g.
1401-
# Interval[int64]->Interval[float64]
1402-
raise
1403-
return blk.where(other, cond)
1400+
blk = self.coerce_to_target_dtype(orig_other)
1401+
nbs = blk.where(orig_other, orig_cond)
1402+
return self._maybe_downcast(nbs, "infer")
14041403

14051404
elif isinstance(self, NDArrayBackedExtensionBlock):
14061405
# NB: not (yet) the same as
14071406
# isinstance(values, NDArrayBackedExtensionArray)
1408-
if isinstance(self.dtype, PeriodDtype):
1409-
# TODO: don't special-case
1410-
raise
1411-
blk = self.coerce_to_target_dtype(other)
1412-
nbs = blk.where(other, cond)
1407+
blk = self.coerce_to_target_dtype(orig_other)
1408+
nbs = blk.where(orig_other, orig_cond)
14131409
return self._maybe_downcast(nbs, "infer")
14141410

14151411
else:
@@ -1426,6 +1422,8 @@ def putmask(self, mask, new) -> list[Block]:
14261422

14271423
values = self.values
14281424

1425+
orig_new = new
1426+
orig_mask = mask
14291427
new = self._maybe_squeeze_arg(new)
14301428
mask = self._maybe_squeeze_arg(mask)
14311429

@@ -1438,21 +1436,14 @@ def putmask(self, mask, new) -> list[Block]:
14381436
if is_interval_dtype(self.dtype):
14391437
# Discussion about what we want to support in the general
14401438
# case GH#39584
1441-
blk = self.coerce_to_target_dtype(new)
1442-
if blk.dtype == _dtype_obj:
1443-
# For now at least, only support casting e.g.
1444-
# Interval[int64]->Interval[float64],
1445-
raise
1446-
return blk.putmask(mask, new)
1439+
blk = self.coerce_to_target_dtype(orig_new)
1440+
return blk.putmask(orig_mask, orig_new)
14471441

14481442
elif isinstance(self, NDArrayBackedExtensionBlock):
14491443
# NB: not (yet) the same as
14501444
# isinstance(values, NDArrayBackedExtensionArray)
1451-
if isinstance(self.dtype, PeriodDtype):
1452-
# TODO: don't special-case
1453-
raise
1454-
blk = self.coerce_to_target_dtype(new)
1455-
return blk.putmask(mask, new)
1445+
blk = self.coerce_to_target_dtype(orig_new)
1446+
return blk.putmask(orig_mask, orig_new)
14561447

14571448
else:
14581449
raise

pandas/tests/arrays/interval/test_interval.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,16 @@ def test_set_closed(self, closed, new_closed):
7676
],
7777
)
7878
def test_where_raises(self, other):
79+
# GH#45768 The IntervalArray methods raises; the Series method coerces
7980
ser = pd.Series(IntervalArray.from_breaks([1, 2, 3, 4], closed="left"))
81+
mask = np.array([True, False, True])
8082
match = "'value.closed' is 'right', expected 'left'."
8183
with pytest.raises(ValueError, match=match):
82-
ser.where([True, False, True], other=other)
84+
ser.array._where(mask, other)
85+
86+
res = ser.where(mask, other=other)
87+
expected = ser.astype(object).where(mask, other)
88+
tm.assert_series_equal(res, expected)
8389

8490
def test_shift(self):
8591
# https://github.com/pandas-dev/pandas/issues/31495, GH#22428, GH#31502

pandas/tests/arrays/test_period.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,16 @@ def test_sub_period():
124124
[pd.Period("2000", freq="H"), period_array(["2000", "2001", "2000"], freq="H")],
125125
)
126126
def test_where_different_freq_raises(other):
127+
# GH#45768 The PeriodArray method raises, the Series method coerces
127128
ser = pd.Series(period_array(["2000", "2001", "2002"], freq="D"))
128129
cond = np.array([True, False, True])
130+
129131
with pytest.raises(IncompatibleFrequency, match="freq"):
130-
ser.where(cond, other)
132+
ser.array._where(cond, other)
133+
134+
res = ser.where(cond, other)
135+
expected = ser.astype(object).where(cond, other)
136+
tm.assert_series_equal(res, expected)
131137

132138

133139
# ----------------------------------------------------------------------------

pandas/tests/frame/indexing/test_where.py

+44-5
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,20 @@ def test_where_interval_noop(self):
706706
res = ser.where(ser.notna())
707707
tm.assert_series_equal(res, ser)
708708

709+
def test_where_interval_fullop_downcast(self, frame_or_series):
710+
# GH#45768
711+
obj = frame_or_series([pd.Interval(0, 0)] * 2)
712+
other = frame_or_series([1.0, 2.0])
713+
res = obj.where(~obj.notna(), other)
714+
715+
# since all entries are being changed, we will downcast result
716+
# from object to ints (not floats)
717+
tm.assert_equal(res, other.astype(np.int64))
718+
719+
# unlike where, Block.putmask does not downcast
720+
obj.mask(obj.notna(), other, inplace=True)
721+
tm.assert_equal(obj, other.astype(object))
722+
709723
@pytest.mark.parametrize(
710724
"dtype",
711725
[
@@ -736,6 +750,16 @@ def test_where_datetimelike_noop(self, dtype):
736750
res4 = df.mask(mask2, "foo")
737751
tm.assert_frame_equal(res4, df)
738752

753+
# opposite case where we are replacing *all* values -> we downcast
754+
# from object dtype # GH#45768
755+
res5 = df.where(mask2, 4)
756+
expected = DataFrame(4, index=df.index, columns=df.columns)
757+
tm.assert_frame_equal(res5, expected)
758+
759+
# unlike where, Block.putmask does not downcast
760+
df.mask(~mask2, 4, inplace=True)
761+
tm.assert_frame_equal(df, expected.astype(object))
762+
739763

740764
def test_where_try_cast_deprecated(frame_or_series):
741765
obj = DataFrame(np.random.randn(4, 3))
@@ -894,14 +918,29 @@ def test_where_period_invalid_na(frame_or_series, as_cat, request):
894918
else:
895919
msg = "value should be a 'Period'"
896920

897-
with pytest.raises(TypeError, match=msg):
898-
obj.where(mask, tdnat)
921+
if as_cat:
922+
with pytest.raises(TypeError, match=msg):
923+
obj.where(mask, tdnat)
899924

900-
with pytest.raises(TypeError, match=msg):
901-
obj.mask(mask, tdnat)
925+
with pytest.raises(TypeError, match=msg):
926+
obj.mask(mask, tdnat)
927+
928+
with pytest.raises(TypeError, match=msg):
929+
obj.mask(mask, tdnat, inplace=True)
930+
931+
else:
932+
# With PeriodDtype, ser[i] = tdnat coerces instead of raising,
933+
# so for consistency, ser[mask] = tdnat must as well
934+
expected = obj.astype(object).where(mask, tdnat)
935+
result = obj.where(mask, tdnat)
936+
tm.assert_equal(result, expected)
937+
938+
expected = obj.astype(object).mask(mask, tdnat)
939+
result = obj.mask(mask, tdnat)
940+
tm.assert_equal(result, expected)
902941

903-
with pytest.raises(TypeError, match=msg):
904942
obj.mask(mask, tdnat, inplace=True)
943+
tm.assert_equal(obj, expected)
905944

906945

907946
def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype):

pandas/tests/series/indexing/test_setitem.py

+17
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
IntervalIndex,
1919
MultiIndex,
2020
NaT,
21+
Period,
2122
Series,
2223
Timedelta,
2324
Timestamp,
@@ -1317,6 +1318,22 @@ def obj(self):
13171318
return Series(timedelta_range("1 day", periods=4))
13181319

13191320

1321+
@pytest.mark.parametrize(
1322+
"val", ["foo", Period("2016", freq="Y"), Interval(1, 2, closed="both")]
1323+
)
1324+
@pytest.mark.parametrize("exp_dtype", [object])
1325+
class TestPeriodIntervalCoercion(CoercionTest):
1326+
# GH#45768
1327+
@pytest.fixture(
1328+
params=[
1329+
period_range("2016-01-01", periods=3, freq="D"),
1330+
interval_range(1, 5),
1331+
]
1332+
)
1333+
def obj(self, request):
1334+
return Series(request.param)
1335+
1336+
13201337
def test_20643():
13211338
# closed by GH#45121
13221339
orig = Series([0, 1, 2], index=["a", "b", "c"])

0 commit comments

Comments
 (0)