Skip to content

Commit 024925a

Browse files
jbrockmendeljreback
authored andcommitted
REF: handle searchsorted casting within DatetimeLikeArray (#30950)
1 parent 8b754fc commit 024925a

File tree

8 files changed

+50
-72
lines changed

8 files changed

+50
-72
lines changed

doc/source/whatsnew/v1.1.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ Categorical
6060
Datetimelike
6161
^^^^^^^^^^^^
6262
- Bug in :class:`Timestamp` where constructing :class:`Timestamp` from ambiguous epoch time and calling constructor again changed :meth:`Timestamp.value` property (:issue:`24329`)
63-
-
63+
- :meth:`DatetimeArray.searchsorted`, :meth:`TimedeltaArray.searchsorted`, :meth:`PeriodArray.searchsorted` not recognizing non-pandas scalars and incorrectly raising ``ValueError`` instead of ``TypeError`` (:issue:`30950`)
6464
-
6565

6666
Timedelta

pandas/core/arrays/datetimelike.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -743,17 +743,36 @@ def searchsorted(self, value, side="left", sorter=None):
743743
Array of insertion points with the same shape as `value`.
744744
"""
745745
if isinstance(value, str):
746-
value = self._scalar_from_string(value)
746+
try:
747+
value = self._scalar_from_string(value)
748+
except ValueError:
749+
raise TypeError("searchsorted requires compatible dtype or scalar")
750+
751+
elif is_valid_nat_for_dtype(value, self.dtype):
752+
value = NaT
753+
754+
elif isinstance(value, self._recognized_scalars):
755+
value = self._scalar_type(value)
756+
757+
elif isinstance(value, np.ndarray):
758+
if not type(self)._is_recognized_dtype(value):
759+
raise TypeError(
760+
"searchsorted requires compatible dtype or scalar, "
761+
f"not {type(value).__name__}"
762+
)
763+
value = type(self)(value)
764+
self._check_compatible_with(value)
747765

748-
if not (isinstance(value, (self._scalar_type, type(self))) or isna(value)):
749-
raise ValueError(f"Unexpected type for 'value': {type(value)}")
766+
if not (isinstance(value, (self._scalar_type, type(self))) or (value is NaT)):
767+
raise TypeError(f"Unexpected type for 'value': {type(value)}")
750768

751-
self._check_compatible_with(value)
752769
if isinstance(value, type(self)):
770+
self._check_compatible_with(value)
753771
value = value.asi8
754772
else:
755773
value = self._unbox_scalar(value)
756774

775+
# TODO: Use datetime64 semantics for sorting, xref GH#29844
757776
return self.asi8.searchsorted(value, side=side, sorter=sorter)
758777

759778
def repeat(self, repeats, *args, **kwargs):

pandas/core/indexes/datetimes.py

+3-14
Original file line numberDiff line numberDiff line change
@@ -833,24 +833,13 @@ def slice_indexer(self, start=None, end=None, step=None, kind=None):
833833
@Substitution(klass="DatetimeIndex")
834834
@Appender(_shared_docs["searchsorted"])
835835
def searchsorted(self, value, side="left", sorter=None):
836-
if isinstance(value, (np.ndarray, Index)):
837-
if not type(self._data)._is_recognized_dtype(value):
838-
raise TypeError(
839-
"searchsorted requires compatible dtype or scalar, "
840-
f"not {type(value).__name__}"
841-
)
842-
value = type(self._data)(value)
843-
self._data._check_compatible_with(value)
844-
845-
elif isinstance(value, self._data._recognized_scalars):
846-
self._data._check_compatible_with(value)
847-
value = self._data._scalar_type(value)
848-
849-
elif not isinstance(value, DatetimeArray):
836+
if isinstance(value, str):
850837
raise TypeError(
851838
"searchsorted requires compatible dtype or scalar, "
852839
f"not {type(value).__name__}"
853840
)
841+
if isinstance(value, Index):
842+
value = value._data
854843

855844
return self._data.searchsorted(value, side=side)
856845

pandas/core/indexes/period.py

-12
Original file line numberDiff line numberDiff line change
@@ -470,18 +470,6 @@ def astype(self, dtype, copy=True, how="start"):
470470
@Substitution(klass="PeriodIndex")
471471
@Appender(_shared_docs["searchsorted"])
472472
def searchsorted(self, value, side="left", sorter=None):
473-
if isinstance(value, Period) or value is NaT:
474-
self._data._check_compatible_with(value)
475-
elif isinstance(value, str):
476-
try:
477-
value = Period(value, freq=self.freq)
478-
except DateParseError:
479-
raise KeyError(f"Cannot interpret '{value}' as period")
480-
elif not isinstance(value, PeriodArray):
481-
raise TypeError(
482-
"PeriodIndex.searchsorted requires either a Period or PeriodArray"
483-
)
484-
485473
return self._data.searchsorted(value, side=side, sorter=sorter)
486474

487475
@property

pandas/core/indexes/timedeltas.py

+3-14
Original file line numberDiff line numberDiff line change
@@ -347,24 +347,13 @@ def _partial_td_slice(self, key):
347347
@Substitution(klass="TimedeltaIndex")
348348
@Appender(_shared_docs["searchsorted"])
349349
def searchsorted(self, value, side="left", sorter=None):
350-
if isinstance(value, (np.ndarray, Index)):
351-
if not type(self._data)._is_recognized_dtype(value):
352-
raise TypeError(
353-
"searchsorted requires compatible dtype or scalar, "
354-
f"not {type(value).__name__}"
355-
)
356-
value = type(self._data)(value)
357-
self._data._check_compatible_with(value)
358-
359-
elif isinstance(value, self._data._recognized_scalars):
360-
self._data._check_compatible_with(value)
361-
value = self._data._scalar_type(value)
362-
363-
elif not isinstance(value, TimedeltaArray):
350+
if isinstance(value, str):
364351
raise TypeError(
365352
"searchsorted requires compatible dtype or scalar, "
366353
f"not {type(value).__name__}"
367354
)
355+
if isinstance(value, Index):
356+
value = value._data
368357

369358
return self._data.searchsorted(value, side=side, sorter=sorter)
370359

pandas/tests/arrays/test_datetimes.py

+7-13
Original file line numberDiff line numberDiff line change
@@ -331,25 +331,19 @@ def test_searchsorted_tzawareness_compat(self, index):
331331
pd.Timestamp.now().to_period("D"),
332332
],
333333
)
334-
@pytest.mark.parametrize(
335-
"index",
336-
[
337-
True,
338-
pytest.param(
339-
False,
340-
marks=pytest.mark.xfail(
341-
reason="Raises ValueError instead of TypeError", raises=ValueError
342-
),
343-
),
344-
],
345-
)
334+
@pytest.mark.parametrize("index", [True, False])
346335
def test_searchsorted_invalid_types(self, other, index):
347336
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
348337
arr = DatetimeArray(data, freq="D")
349338
if index:
350339
arr = pd.Index(arr)
351340

352-
msg = "searchsorted requires compatible dtype or scalar"
341+
msg = "|".join(
342+
[
343+
"searchsorted requires compatible dtype or scalar",
344+
"Unexpected type for 'value'",
345+
]
346+
)
353347
with pytest.raises(TypeError, match=msg):
354348
arr.searchsorted(other)
355349

pandas/tests/arrays/test_timedeltas.py

+7-13
Original file line numberDiff line numberDiff line change
@@ -154,25 +154,19 @@ def test_setitem_objects(self, obj):
154154
pd.Timestamp.now().to_period("D"),
155155
],
156156
)
157-
@pytest.mark.parametrize(
158-
"index",
159-
[
160-
True,
161-
pytest.param(
162-
False,
163-
marks=pytest.mark.xfail(
164-
reason="Raises ValueError instead of TypeError", raises=ValueError
165-
),
166-
),
167-
],
168-
)
157+
@pytest.mark.parametrize("index", [True, False])
169158
def test_searchsorted_invalid_types(self, other, index):
170159
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
171160
arr = TimedeltaArray(data, freq="D")
172161
if index:
173162
arr = pd.Index(arr)
174163

175-
msg = "searchsorted requires compatible dtype or scalar"
164+
msg = "|".join(
165+
[
166+
"searchsorted requires compatible dtype or scalar",
167+
"Unexpected type for 'value'",
168+
]
169+
)
176170
with pytest.raises(TypeError, match=msg):
177171
arr.searchsorted(other)
178172

pandas/tests/indexes/period/test_tools.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,12 @@ def test_searchsorted_invalid(self):
249249

250250
other = np.array([0, 1], dtype=np.int64)
251251

252-
msg = "requires either a Period or PeriodArray"
252+
msg = "|".join(
253+
[
254+
"searchsorted requires compatible dtype or scalar",
255+
"Unexpected type for 'value'",
256+
]
257+
)
253258
with pytest.raises(TypeError, match=msg):
254259
pidx.searchsorted(other)
255260

0 commit comments

Comments
 (0)