Skip to content

Commit bde2527

Browse files
jbrockmendeljreback
authored andcommitted
BUG: DTI/TDI/PI where accepting non-matching dtypes (#30791)
1 parent 425c2fb commit bde2527

File tree

5 files changed

+90
-107
lines changed

5 files changed

+90
-107
lines changed

pandas/core/arrays/datetimelike.py

+12-83
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
is_unsigned_integer_dtype,
3535
pandas_dtype,
3636
)
37-
from pandas.core.dtypes.generic import ABCIndexClass, ABCPeriodArray, ABCSeries
37+
from pandas.core.dtypes.generic import ABCSeries
3838
from pandas.core.dtypes.inference import is_array_like
3939
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna
4040

@@ -368,16 +368,19 @@ class TimelikeOps:
368368

369369
def _round(self, freq, mode, ambiguous, nonexistent):
370370
# round the local times
371-
values = _ensure_datetimelike_to_i8(self)
371+
if is_datetime64tz_dtype(self):
372+
# operate on naive timestamps, then convert back to aware
373+
naive = self.tz_localize(None)
374+
result = naive._round(freq, mode, ambiguous, nonexistent)
375+
aware = result.tz_localize(
376+
self.tz, ambiguous=ambiguous, nonexistent=nonexistent
377+
)
378+
return aware
379+
380+
values = self.view("i8")
372381
result = round_nsint64(values, mode, freq)
373382
result = self._maybe_mask_results(result, fill_value=NaT)
374-
375-
dtype = self.dtype
376-
if is_datetime64tz_dtype(self):
377-
dtype = None
378-
return self._ensure_localized(
379-
self._simple_new(result, dtype=dtype), ambiguous, nonexistent
380-
)
383+
return self._simple_new(result, dtype=self.dtype)
381384

382385
@Appender((_round_doc + _round_example).format(op="round"))
383386
def round(self, freq, ambiguous="raise", nonexistent="raise"):
@@ -1411,45 +1414,6 @@ def __isub__(self, other): # type: ignore
14111414
self._freq = result._freq
14121415
return self
14131416

1414-
# --------------------------------------------------------------
1415-
# Comparison Methods
1416-
1417-
def _ensure_localized(
1418-
self, arg, ambiguous="raise", nonexistent="raise", from_utc=False
1419-
):
1420-
"""
1421-
Ensure that we are re-localized.
1422-
1423-
This is for compat as we can then call this on all datetimelike
1424-
arrays generally (ignored for Period/Timedelta)
1425-
1426-
Parameters
1427-
----------
1428-
arg : Union[DatetimeLikeArray, DatetimeIndexOpsMixin, ndarray]
1429-
ambiguous : str, bool, or bool-ndarray, default 'raise'
1430-
nonexistent : str, default 'raise'
1431-
from_utc : bool, default False
1432-
If True, localize the i8 ndarray to UTC first before converting to
1433-
the appropriate tz. If False, localize directly to the tz.
1434-
1435-
Returns
1436-
-------
1437-
localized array
1438-
"""
1439-
1440-
# reconvert to local tz
1441-
tz = getattr(self, "tz", None)
1442-
if tz is not None:
1443-
if not isinstance(arg, type(self)):
1444-
arg = self._simple_new(arg)
1445-
if from_utc:
1446-
arg = arg.tz_localize("UTC").tz_convert(self.tz)
1447-
else:
1448-
arg = arg.tz_localize(
1449-
self.tz, ambiguous=ambiguous, nonexistent=nonexistent
1450-
)
1451-
return arg
1452-
14531417
# --------------------------------------------------------------
14541418
# Reductions
14551419

@@ -1687,38 +1651,3 @@ def maybe_infer_freq(freq):
16871651
freq_infer = True
16881652
freq = None
16891653
return freq, freq_infer
1690-
1691-
1692-
def _ensure_datetimelike_to_i8(other, to_utc=False):
1693-
"""
1694-
Helper for coercing an input scalar or array to i8.
1695-
1696-
Parameters
1697-
----------
1698-
other : 1d array
1699-
to_utc : bool, default False
1700-
If True, convert the values to UTC before extracting the i8 values
1701-
If False, extract the i8 values directly.
1702-
1703-
Returns
1704-
-------
1705-
i8 1d array
1706-
"""
1707-
from pandas import Index
1708-
1709-
if lib.is_scalar(other) and isna(other):
1710-
return iNaT
1711-
elif isinstance(other, (ABCPeriodArray, ABCIndexClass, DatetimeLikeArrayMixin)):
1712-
# convert tz if needed
1713-
if getattr(other, "tz", None) is not None:
1714-
if to_utc:
1715-
other = other.tz_convert("UTC")
1716-
else:
1717-
other = other.tz_localize(None)
1718-
else:
1719-
try:
1720-
return np.array(other, copy=False).view("i8")
1721-
except TypeError:
1722-
# period array cannot be coerced to int
1723-
other = Index(other)
1724-
return other.asi8

pandas/core/indexes/datetimelike.py

+24-20
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@
1616
from pandas.core.dtypes.common import (
1717
ensure_int64,
1818
is_bool_dtype,
19+
is_categorical_dtype,
1920
is_dtype_equal,
2021
is_float,
2122
is_integer,
2223
is_list_like,
2324
is_period_dtype,
2425
is_scalar,
26+
needs_i8_conversion,
2527
)
2628
from pandas.core.dtypes.concat import concat_compat
2729
from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries
30+
from pandas.core.dtypes.missing import isna
2831

2932
from pandas.core import algorithms
3033
from pandas.core.accessor import PandasDelegate
@@ -34,10 +37,7 @@
3437
ExtensionOpsMixin,
3538
TimedeltaArray,
3639
)
37-
from pandas.core.arrays.datetimelike import (
38-
DatetimeLikeArrayMixin,
39-
_ensure_datetimelike_to_i8,
40-
)
40+
from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin
4141
import pandas.core.indexes.base as ibase
4242
from pandas.core.indexes.base import Index, _index_shared_docs
4343
from pandas.core.indexes.numeric import Int64Index
@@ -166,18 +166,6 @@ def equals(self, other) -> bool:
166166

167167
return np.array_equal(self.asi8, other.asi8)
168168

169-
def _ensure_localized(
170-
self, arg, ambiguous="raise", nonexistent="raise", from_utc=False
171-
):
172-
# See DatetimeLikeArrayMixin._ensure_localized.__doc__
173-
if getattr(self, "tz", None):
174-
# ensure_localized is only relevant for tz-aware DTI
175-
result = self._data._ensure_localized(
176-
arg, ambiguous=ambiguous, nonexistent=nonexistent, from_utc=from_utc
177-
)
178-
return type(self)._simple_new(result, name=self.name)
179-
return arg
180-
181169
@Appender(_index_shared_docs["contains"] % _index_doc_kwargs)
182170
def __contains__(self, key):
183171
try:
@@ -480,11 +468,27 @@ def repeat(self, repeats, axis=None):
480468

481469
@Appender(_index_shared_docs["where"] % _index_doc_kwargs)
482470
def where(self, cond, other=None):
483-
other = _ensure_datetimelike_to_i8(other, to_utc=True)
484-
values = _ensure_datetimelike_to_i8(self, to_utc=True)
485-
result = np.where(cond, values, other).astype("i8")
471+
values = self.view("i8")
472+
473+
if is_scalar(other) and isna(other):
474+
other = NaT.value
486475

487-
result = self._ensure_localized(result, from_utc=True)
476+
else:
477+
# Do type inference if necessary up front
478+
# e.g. we passed PeriodIndex.values and got an ndarray of Periods
479+
other = Index(other)
480+
481+
if is_categorical_dtype(other):
482+
# e.g. we have a Categorical holding self.dtype
483+
if needs_i8_conversion(other.categories):
484+
other = other._internal_get_values()
485+
486+
if not is_dtype_equal(self.dtype, other.dtype):
487+
raise TypeError(f"Where requires matching dtype, not {other.dtype}")
488+
489+
other = other.view("i8")
490+
491+
result = np.where(cond, values, other).astype("i8")
488492
return self._shallow_copy(result)
489493

490494
def _summary(self, name=None):

pandas/tests/indexes/datetimes/test_indexing.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,32 @@ def test_where_other(self):
132132

133133
i2 = i.copy()
134134
i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
135-
result = i.where(notna(i2), i2.values)
135+
result = i.where(notna(i2), i2._values)
136136
tm.assert_index_equal(result, i2)
137137

138+
def test_where_invalid_dtypes(self):
139+
dti = pd.date_range("20130101", periods=3, tz="US/Eastern")
140+
141+
i2 = dti.copy()
142+
i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist())
143+
144+
with pytest.raises(TypeError, match="Where requires matching dtype"):
145+
# passing tz-naive ndarray to tzaware DTI
146+
dti.where(notna(i2), i2.values)
147+
148+
with pytest.raises(TypeError, match="Where requires matching dtype"):
149+
# passing tz-aware DTI to tznaive DTI
150+
dti.tz_localize(None).where(notna(i2), i2)
151+
152+
with pytest.raises(TypeError, match="Where requires matching dtype"):
153+
dti.where(notna(i2), i2.tz_localize(None).to_period("D"))
154+
155+
with pytest.raises(TypeError, match="Where requires matching dtype"):
156+
dti.where(notna(i2), i2.asi8.view("timedelta64[ns]"))
157+
158+
with pytest.raises(TypeError, match="Where requires matching dtype"):
159+
dti.where(notna(i2), i2.asi8)
160+
138161
def test_where_tz(self):
139162
i = pd.date_range("20130101", periods=3, tz="US/Eastern")
140163
result = i.where(notna(i))

pandas/tests/indexes/period/test_indexing.py

+15
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,21 @@ def test_where_other(self):
235235
result = i.where(notna(i2), i2.values)
236236
tm.assert_index_equal(result, i2)
237237

238+
def test_where_invalid_dtypes(self):
239+
pi = period_range("20130101", periods=5, freq="D")
240+
241+
i2 = pi.copy()
242+
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + pi[2:].tolist(), freq="D")
243+
244+
with pytest.raises(TypeError, match="Where requires matching dtype"):
245+
pi.where(notna(i2), i2.asi8)
246+
247+
with pytest.raises(TypeError, match="Where requires matching dtype"):
248+
pi.where(notna(i2), i2.asi8.view("timedelta64[ns]"))
249+
250+
with pytest.raises(TypeError, match="Where requires matching dtype"):
251+
pi.where(notna(i2), i2.to_timestamp("S"))
252+
238253

239254
class TestTake:
240255
def test_take(self):

pandas/tests/indexes/timedeltas/test_indexing.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
import pandas as pd
7-
from pandas import Index, Timedelta, TimedeltaIndex, timedelta_range
7+
from pandas import Index, Timedelta, TimedeltaIndex, notna, timedelta_range
88
import pandas._testing as tm
99

1010

@@ -58,8 +58,20 @@ def test_timestamp_invalid_key(self, key):
5858

5959

6060
class TestWhere:
61-
# placeholder for symmetry with DatetimeIndex and PeriodIndex tests
62-
pass
61+
def test_where_invalid_dtypes(self):
62+
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
63+
64+
i2 = tdi.copy()
65+
i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist())
66+
67+
with pytest.raises(TypeError, match="Where requires matching dtype"):
68+
tdi.where(notna(i2), i2.asi8)
69+
70+
with pytest.raises(TypeError, match="Where requires matching dtype"):
71+
tdi.where(notna(i2), i2 + pd.Timestamp.now())
72+
73+
with pytest.raises(TypeError, match="Where requires matching dtype"):
74+
tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D"))
6375

6476

6577
class TestTake:

0 commit comments

Comments
 (0)