Skip to content

Commit f2ff676

Browse files
authored
BUG: Series[dt64tz].__setitem__(mask, different_tz) raises (#39656)
1 parent 4b066bb commit f2ff676

File tree

10 files changed

+154
-98
lines changed

10 files changed

+154
-98
lines changed

pandas/core/indexes/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4528,7 +4528,7 @@ def putmask(self, mask, value):
45284528
if not mask.any():
45294529
return self.copy()
45304530

4531-
if value is None:
4531+
if value is None and (self._is_numeric_dtype or self.dtype == object):
45324532
value = self._na_value
45334533
try:
45344534
converted = self._validate_fill_value(value)

pandas/core/indexes/extension.py

-5
Original file line numberDiff line numberDiff line change
@@ -353,11 +353,6 @@ def insert(self: _T, loc: int, item) -> _T:
353353
new_arr = arr._from_backing_data(new_vals)
354354
return type(self)._simple_new(new_arr, name=self.name)
355355

356-
@doc(Index.where)
357-
def where(self: _T, cond: np.ndarray, other=None) -> _T:
358-
res_values = self._data.where(cond, other)
359-
return type(self)._simple_new(res_values, name=self.name)
360-
361356
def putmask(self, mask, value):
362357
res_values = self._data.copy()
363358
try:

pandas/core/internals/blocks.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -1014,14 +1014,6 @@ def putmask(self, mask, new) -> List[Block]:
10141014
new = self.fill_value
10151015

10161016
if self._can_hold_element(new):
1017-
if self.dtype.kind in ["m", "M"]:
1018-
arr = self.array_values()
1019-
arr = cast("NDArrayBackedExtensionArray", arr)
1020-
if transpose:
1021-
arr = arr.T
1022-
arr.putmask(mask, new)
1023-
return [self]
1024-
10251017
if transpose:
10261018
new_values = new_values.T
10271019

@@ -2025,6 +2017,18 @@ def to_native_types(self, na_rep="NaT", **kwargs):
20252017
result = arr._format_native_types(na_rep=na_rep, **kwargs)
20262018
return self.make_block(result)
20272019

2020+
def putmask(self, mask, new) -> List[Block]:
2021+
mask = _extract_bool_array(mask)
2022+
2023+
if not self._can_hold_element(new):
2024+
return self.astype(object).putmask(mask, new)
2025+
2026+
# TODO(EA2D): reshape unnecessary with 2D EAs
2027+
arr = self.array_values().reshape(self.shape)
2028+
arr = cast("NDArrayBackedExtensionArray", arr)
2029+
arr.T.putmask(mask, new)
2030+
return [self]
2031+
20282032
def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
20292033
# TODO(EA2D): reshape unnecessary with 2D EAs
20302034
arr = self.array_values().reshape(self.shape)
@@ -2099,6 +2103,7 @@ class DatetimeTZBlock(ExtensionBlock, DatetimeBlock):
20992103
diff = DatetimeBlock.diff
21002104
fill_value = np.datetime64("NaT", "ns")
21012105
where = DatetimeBlock.where
2106+
putmask = DatetimeLikeBlockMixin.putmask
21022107

21032108
array_values = ExtensionBlock.array_values
21042109

pandas/tests/indexes/categorical/test_indexing.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,11 @@ def test_where_non_categories(self):
305305
ci = CategoricalIndex(["a", "b", "c", "d"])
306306
mask = np.array([True, False, True, False])
307307

308-
msg = "Cannot setitem on a Categorical with a new category"
309-
with pytest.raises(ValueError, match=msg):
310-
ci.where(mask, 2)
308+
result = ci.where(mask, 2)
309+
expected = Index(["a", 2, "c", 2], dtype=object)
310+
tm.assert_index_equal(result, expected)
311311

312+
msg = "Cannot setitem on a Categorical with a new category"
312313
with pytest.raises(ValueError, match=msg):
313314
# Test the Categorical method directly
314315
ci._data.where(mask, 2)

pandas/tests/indexes/datetimelike.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ def test_where_cast_str(self):
117117
result = index.where(mask, [str(index[0])])
118118
tm.assert_index_equal(result, expected)
119119

120-
msg = "value should be a '.*', 'NaT', or array of those"
121-
with pytest.raises(TypeError, match=msg):
122-
index.where(mask, "foo")
120+
expected = index.astype(object).where(mask, "foo")
121+
result = index.where(mask, "foo")
122+
tm.assert_index_equal(result, expected)
123123

124-
with pytest.raises(TypeError, match=msg):
125-
index.where(mask, ["foo"])
124+
result = index.where(mask, ["foo"])
125+
tm.assert_index_equal(result, expected)

pandas/tests/indexes/datetimes/test_indexing.py

+38-22
Original file line numberDiff line numberDiff line change
@@ -175,40 +175,56 @@ def test_where_other(self):
175175
def test_where_invalid_dtypes(self):
176176
dti = date_range("20130101", periods=3, tz="US/Eastern")
177177

178-
i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist())
178+
tail = dti[2:].tolist()
179+
i2 = Index([pd.NaT, pd.NaT] + tail)
179180

180-
msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
181-
msg2 = "Cannot compare tz-naive and tz-aware datetime-like objects"
182-
with pytest.raises(TypeError, match=msg2):
183-
# passing tz-naive ndarray to tzaware DTI
184-
dti.where(notna(i2), i2.values)
181+
mask = notna(i2)
185182

186-
with pytest.raises(TypeError, match=msg2):
187-
# passing tz-aware DTI to tznaive DTI
188-
dti.tz_localize(None).where(notna(i2), i2)
183+
# passing tz-naive ndarray to tzaware DTI
184+
result = dti.where(mask, i2.values)
185+
expected = Index([pd.NaT.asm8, pd.NaT.asm8] + tail, dtype=object)
186+
tm.assert_index_equal(result, expected)
189187

190-
with pytest.raises(TypeError, match=msg):
191-
dti.where(notna(i2), i2.tz_localize(None).to_period("D"))
188+
# passing tz-aware DTI to tznaive DTI
189+
naive = dti.tz_localize(None)
190+
result = naive.where(mask, i2)
191+
expected = Index([i2[0], i2[1]] + naive[2:].tolist(), dtype=object)
192+
tm.assert_index_equal(result, expected)
192193

193-
with pytest.raises(TypeError, match=msg):
194-
dti.where(notna(i2), i2.asi8.view("timedelta64[ns]"))
194+
pi = i2.tz_localize(None).to_period("D")
195+
result = dti.where(mask, pi)
196+
expected = Index([pi[0], pi[1]] + tail, dtype=object)
197+
tm.assert_index_equal(result, expected)
195198

196-
with pytest.raises(TypeError, match=msg):
197-
dti.where(notna(i2), i2.asi8)
199+
tda = i2.asi8.view("timedelta64[ns]")
200+
result = dti.where(mask, tda)
201+
expected = Index([tda[0], tda[1]] + tail, dtype=object)
202+
assert isinstance(expected[0], np.timedelta64)
203+
tm.assert_index_equal(result, expected)
198204

199-
with pytest.raises(TypeError, match=msg):
200-
# non-matching scalar
201-
dti.where(notna(i2), pd.Timedelta(days=4))
205+
result = dti.where(mask, i2.asi8)
206+
expected = Index([pd.NaT.value, pd.NaT.value] + tail, dtype=object)
207+
assert isinstance(expected[0], int)
208+
tm.assert_index_equal(result, expected)
209+
210+
# non-matching scalar
211+
td = pd.Timedelta(days=4)
212+
result = dti.where(mask, td)
213+
expected = Index([td, td] + tail, dtype=object)
214+
assert expected[0] is td
215+
tm.assert_index_equal(result, expected)
202216

203217
def test_where_mismatched_nat(self, tz_aware_fixture):
204218
tz = tz_aware_fixture
205219
dti = date_range("2013-01-01", periods=3, tz=tz)
206220
cond = np.array([True, False, True])
207221

208-
msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
209-
with pytest.raises(TypeError, match=msg):
210-
# wrong-dtyped NaT
211-
dti.where(cond, np.timedelta64("NaT", "ns"))
222+
tdnat = np.timedelta64("NaT", "ns")
223+
expected = Index([dti[0], tdnat, dti[2]], dtype=object)
224+
assert expected[1] is tdnat
225+
226+
result = dti.where(cond, tdnat)
227+
tm.assert_index_equal(result, expected)
212228

213229
def test_where_tz(self):
214230
i = date_range("20130101", periods=3, tz="US/Eastern")

pandas/tests/indexes/period/test_indexing.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -603,30 +603,42 @@ def test_where_other(self):
603603
def test_where_invalid_dtypes(self):
604604
pi = period_range("20130101", periods=5, freq="D")
605605

606-
i2 = PeriodIndex([NaT, NaT] + pi[2:].tolist(), freq="D")
606+
tail = pi[2:].tolist()
607+
i2 = PeriodIndex([NaT, NaT] + tail, freq="D")
608+
mask = notna(i2)
607609

608-
msg = "value should be a 'Period', 'NaT', or array of those"
609-
with pytest.raises(TypeError, match=msg):
610-
pi.where(notna(i2), i2.asi8)
610+
result = pi.where(mask, i2.asi8)
611+
expected = pd.Index([NaT.value, NaT.value] + tail, dtype=object)
612+
assert isinstance(expected[0], int)
613+
tm.assert_index_equal(result, expected)
611614

612-
with pytest.raises(TypeError, match=msg):
613-
pi.where(notna(i2), i2.asi8.view("timedelta64[ns]"))
615+
tdi = i2.asi8.view("timedelta64[ns]")
616+
expected = pd.Index([tdi[0], tdi[1]] + tail, dtype=object)
617+
assert isinstance(expected[0], np.timedelta64)
618+
result = pi.where(mask, tdi)
619+
tm.assert_index_equal(result, expected)
614620

615-
with pytest.raises(TypeError, match=msg):
616-
pi.where(notna(i2), i2.to_timestamp("S"))
621+
dti = i2.to_timestamp("S")
622+
expected = pd.Index([dti[0], dti[1]] + tail, dtype=object)
623+
assert expected[0] is NaT
624+
result = pi.where(mask, dti)
625+
tm.assert_index_equal(result, expected)
617626

618-
with pytest.raises(TypeError, match=msg):
619-
# non-matching scalar
620-
pi.where(notna(i2), Timedelta(days=4))
627+
td = Timedelta(days=4)
628+
expected = pd.Index([td, td] + tail, dtype=object)
629+
assert expected[0] == td
630+
result = pi.where(mask, td)
631+
tm.assert_index_equal(result, expected)
621632

622633
def test_where_mismatched_nat(self):
623634
pi = period_range("20130101", periods=5, freq="D")
624635
cond = np.array([True, False, True, True, False])
625636

626-
msg = "value should be a 'Period', 'NaT', or array of those"
627-
with pytest.raises(TypeError, match=msg):
628-
# wrong-dtyped NaT
629-
pi.where(cond, np.timedelta64("NaT", "ns"))
637+
tdnat = np.timedelta64("NaT", "ns")
638+
expected = pd.Index([pi[0], tdnat, pi[2], pi[3], tdnat], dtype=object)
639+
assert expected[1] is tdnat
640+
result = pi.where(cond, tdnat)
641+
tm.assert_index_equal(result, expected)
630642

631643

632644
class TestTake:

pandas/tests/indexes/timedeltas/test_indexing.py

+24-15
Original file line numberDiff line numberDiff line change
@@ -148,30 +148,39 @@ def test_where_doesnt_retain_freq(self):
148148
def test_where_invalid_dtypes(self):
149149
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
150150

151-
i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist())
151+
tail = tdi[2:].tolist()
152+
i2 = Index([pd.NaT, pd.NaT] + tail)
153+
mask = notna(i2)
152154

153-
msg = "value should be a 'Timedelta', 'NaT', or array of those"
154-
with pytest.raises(TypeError, match=msg):
155-
tdi.where(notna(i2), i2.asi8)
155+
expected = Index([pd.NaT.value, pd.NaT.value] + tail, dtype=object, name="idx")
156+
assert isinstance(expected[0], int)
157+
result = tdi.where(mask, i2.asi8)
158+
tm.assert_index_equal(result, expected)
156159

157-
with pytest.raises(TypeError, match=msg):
158-
tdi.where(notna(i2), i2 + pd.Timestamp.now())
160+
ts = i2 + pd.Timestamp.now()
161+
expected = Index([ts[0], ts[1]] + tail, dtype=object, name="idx")
162+
result = tdi.where(mask, ts)
163+
tm.assert_index_equal(result, expected)
159164

160-
with pytest.raises(TypeError, match=msg):
161-
tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D"))
165+
per = (i2 + pd.Timestamp.now()).to_period("D")
166+
expected = Index([per[0], per[1]] + tail, dtype=object, name="idx")
167+
result = tdi.where(mask, per)
168+
tm.assert_index_equal(result, expected)
162169

163-
with pytest.raises(TypeError, match=msg):
164-
# non-matching scalar
165-
tdi.where(notna(i2), pd.Timestamp.now())
170+
ts = pd.Timestamp.now()
171+
expected = Index([ts, ts] + tail, dtype=object, name="idx")
172+
result = tdi.where(mask, ts)
173+
tm.assert_index_equal(result, expected)
166174

167175
def test_where_mismatched_nat(self):
168176
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
169177
cond = np.array([True, False, False])
170178

171-
msg = "value should be a 'Timedelta', 'NaT', or array of those"
172-
with pytest.raises(TypeError, match=msg):
173-
# wrong-dtyped NaT
174-
tdi.where(cond, np.datetime64("NaT", "ns"))
179+
dtnat = np.datetime64("NaT", "ns")
180+
expected = Index([tdi[0], dtnat, dtnat], dtype=object, name="idx")
181+
assert expected[2] is dtnat
182+
result = tdi.where(cond, dtnat)
183+
tm.assert_index_equal(result, expected)
175184

176185

177186
class TestTake:

pandas/tests/indexing/test_coercion.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -802,10 +802,13 @@ def test_where_index_timedelta64(self, value):
802802
result = tdi.where(cond, value)
803803
tm.assert_index_equal(result, expected)
804804

805-
msg = "value should be a 'Timedelta', 'NaT', or array of thos"
806-
with pytest.raises(TypeError, match=msg):
807-
# wrong-dtyped NaT
808-
tdi.where(cond, np.datetime64("NaT", "ns"))
805+
# wrong-dtyped NaT
806+
dtnat = np.datetime64("NaT", "ns")
807+
expected = pd.Index([tdi[0], dtnat, dtnat, tdi[3]], dtype=object)
808+
assert expected[1] is dtnat
809+
810+
result = tdi.where(cond, dtnat)
811+
tm.assert_index_equal(result, expected)
809812

810813
def test_where_index_period(self):
811814
dti = pd.date_range("2016-01-01", periods=3, freq="QS")
@@ -825,14 +828,16 @@ def test_where_index_period(self):
825828
expected = pd.PeriodIndex([other[0], pi[1], other[2]])
826829
tm.assert_index_equal(result, expected)
827830

828-
# Passing a mismatched scalar
829-
msg = "value should be a 'Period', 'NaT', or array of those"
830-
with pytest.raises(TypeError, match=msg):
831-
pi.where(cond, pd.Timedelta(days=4))
831+
# Passing a mismatched scalar -> casts to object
832+
td = pd.Timedelta(days=4)
833+
expected = pd.Index([td, pi[1], td], dtype=object)
834+
result = pi.where(cond, td)
835+
tm.assert_index_equal(result, expected)
832836

833-
msg = r"Input has different freq=D from PeriodArray\(freq=Q-DEC\)"
834-
with pytest.raises(ValueError, match=msg):
835-
pi.where(cond, pd.Period("2020-04-21", "D"))
837+
per = pd.Period("2020-04-21", "D")
838+
expected = pd.Index([per, pi[1], per], dtype=object)
839+
result = pi.where(cond, per)
840+
tm.assert_index_equal(result, expected)
836841

837842

838843
class TestFillnaSeriesCoercion(CoercionBase):

pandas/tests/series/indexing/test_setitem.py

+26-13
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,6 @@ def test_setitem_with_string_index(self):
5050
assert ser.Date == date.today()
5151
assert ser["Date"] == date.today()
5252

53-
def test_setitem_with_different_tz_casts_to_object(self):
54-
# GH#24024
55-
ser = Series(date_range("2000", periods=2, tz="US/Central"))
56-
ser[0] = Timestamp("2000", tz="US/Eastern")
57-
expected = Series(
58-
[
59-
Timestamp("2000-01-01 00:00:00-05:00", tz="US/Eastern"),
60-
Timestamp("2000-01-02 00:00:00-06:00", tz="US/Central"),
61-
],
62-
dtype=object,
63-
)
64-
tm.assert_series_equal(ser, expected)
65-
6653
def test_setitem_tuple_with_datetimetz_values(self):
6754
# GH#20441
6855
arr = date_range("2017", periods=4, tz="US/Eastern")
@@ -646,3 +633,29 @@ def val(self, request):
646633
@pytest.fixture
647634
def is_inplace(self):
648635
return True
636+
637+
638+
class TestSetitemMismatchedTZCastsToObject(SetitemCastingEquivalents):
639+
# GH#24024
640+
@pytest.fixture
641+
def obj(self):
642+
return Series(date_range("2000", periods=2, tz="US/Central"))
643+
644+
@pytest.fixture
645+
def val(self):
646+
return Timestamp("2000", tz="US/Eastern")
647+
648+
@pytest.fixture
649+
def key(self):
650+
return 0
651+
652+
@pytest.fixture
653+
def expected(self):
654+
expected = Series(
655+
[
656+
Timestamp("2000-01-01 00:00:00-05:00", tz="US/Eastern"),
657+
Timestamp("2000-01-02 00:00:00-06:00", tz="US/Central"),
658+
],
659+
dtype=object,
660+
)
661+
return expected

0 commit comments

Comments
 (0)