Skip to content

Commit 6645079

Browse files
jbrockmendelrhshadrach
authored andcommitted
REF: de-duplicate listlike validation in DTA._validate_foo (pandas-dev#33908)
1 parent 18bc891 commit 6645079

File tree

1 file changed

+57
-59
lines changed

1 file changed

+57
-59
lines changed

pandas/core/arrays/datetimelike.py

+57-59
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import datetime, timedelta
22
import operator
3-
from typing import Any, Sequence, Type, TypeVar, Union, cast
3+
from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast
44
import warnings
55

66
import numpy as np
@@ -10,7 +10,7 @@
1010
from pandas._libs.tslibs.period import DIFFERENT_FREQ, IncompatibleFrequency, Period
1111
from pandas._libs.tslibs.timedeltas import delta_to_nanoseconds
1212
from pandas._libs.tslibs.timestamps import RoundTo, round_nsint64
13-
from pandas._typing import DatetimeLikeScalar
13+
from pandas._typing import DatetimeLikeScalar, DtypeObj
1414
from pandas.compat import set_function_name
1515
from pandas.compat.numpy import function as nv
1616
from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning
@@ -86,24 +86,10 @@ def _validate_comparison_value(self, other):
8686
raise ValueError("Lengths must match")
8787

8888
else:
89-
if isinstance(other, list):
90-
# TODO: could use pd.Index to do inference?
91-
other = np.array(other)
92-
93-
if not isinstance(other, (np.ndarray, type(self))):
94-
raise InvalidComparison(other)
95-
96-
elif is_object_dtype(other.dtype):
97-
pass
98-
99-
elif not type(self)._is_recognized_dtype(other.dtype):
100-
raise InvalidComparison(other)
101-
102-
else:
103-
# For PeriodDType this casting is unnecessary
104-
# TODO: use Index to do inference?
105-
other = type(self)._from_sequence(other)
106-
self._check_compatible_with(other)
89+
try:
90+
other = self._validate_listlike(other, opname, allow_object=True)
91+
except TypeError as err:
92+
raise InvalidComparison(other) from err
10793

10894
return other
10995

@@ -451,6 +437,8 @@ class DatetimeLikeArrayMixin(
451437
_generate_range
452438
"""
453439

440+
_is_recognized_dtype: Callable[[DtypeObj], bool]
441+
454442
# ------------------------------------------------------------------
455443
# NDArrayBackedExtensionArray compat
456444

@@ -770,6 +758,48 @@ def _validate_shift_value(self, fill_value):
770758

771759
return self._unbox(fill_value)
772760

761+
def _validate_listlike(
762+
self,
763+
value,
764+
opname: str,
765+
cast_str: bool = False,
766+
cast_cat: bool = False,
767+
allow_object: bool = False,
768+
):
769+
if isinstance(value, type(self)):
770+
return value
771+
772+
# Do type inference if necessary up front
773+
# e.g. we passed PeriodIndex.values and got an ndarray of Periods
774+
value = array(value)
775+
value = extract_array(value, extract_numpy=True)
776+
777+
if cast_str and is_dtype_equal(value.dtype, "string"):
778+
# We got a StringArray
779+
try:
780+
# TODO: Could use from_sequence_of_strings if implemented
781+
# Note: passing dtype is necessary for PeriodArray tests
782+
value = type(self)._from_sequence(value, dtype=self.dtype)
783+
except ValueError:
784+
pass
785+
786+
if cast_cat and is_categorical_dtype(value.dtype):
787+
# e.g. we have a Categorical holding self.dtype
788+
if is_dtype_equal(value.categories.dtype, self.dtype):
789+
# TODO: do we need equal dtype or just comparable?
790+
value = value._internal_get_values()
791+
792+
if allow_object and is_object_dtype(value.dtype):
793+
pass
794+
795+
elif not type(self)._is_recognized_dtype(value.dtype):
796+
raise TypeError(
797+
f"{opname} requires compatible dtype or scalar, "
798+
f"not {type(value).__name__}"
799+
)
800+
801+
return value
802+
773803
def _validate_searchsorted_value(self, value):
774804
if isinstance(value, str):
775805
try:
@@ -785,41 +815,19 @@ def _validate_searchsorted_value(self, value):
785815
elif isinstance(value, self._recognized_scalars):
786816
value = self._scalar_type(value)
787817

788-
elif isinstance(value, type(self)):
789-
pass
790-
791-
elif is_list_like(value) and not isinstance(value, type(self)):
792-
value = array(value)
793-
794-
if not type(self)._is_recognized_dtype(value.dtype):
795-
raise TypeError(
796-
"searchsorted requires compatible dtype or scalar, "
797-
f"not {type(value).__name__}"
798-
)
818+
elif not is_list_like(value):
819+
raise TypeError(f"Unexpected type for 'value': {type(value)}")
799820

800821
else:
801-
raise TypeError(f"Unexpected type for 'value': {type(value)}")
822+
# TODO: cast_str? we accept it for scalar
823+
value = self._validate_listlike(value, "searchsorted")
802824

803825
return self._unbox(value)
804826

805827
def _validate_setitem_value(self, value):
806828

807829
if is_list_like(value):
808-
value = array(value)
809-
if is_dtype_equal(value.dtype, "string"):
810-
# We got a StringArray
811-
try:
812-
# TODO: Could use from_sequence_of_strings if implemented
813-
# Note: passing dtype is necessary for PeriodArray tests
814-
value = type(self)._from_sequence(value, dtype=self.dtype)
815-
except ValueError:
816-
pass
817-
818-
if not type(self)._is_recognized_dtype(value.dtype):
819-
raise TypeError(
820-
"setitem requires compatible dtype or scalar, "
821-
f"not {type(value).__name__}"
822-
)
830+
value = self._validate_listlike(value, "setitem", cast_str=True)
823831

824832
elif isinstance(value, self._recognized_scalars):
825833
value = self._scalar_type(value)
@@ -860,18 +868,8 @@ def _validate_where_value(self, other):
860868
raise TypeError(f"Where requires matching dtype, not {type(other)}")
861869

862870
else:
863-
# Do type inference if necessary up front
864-
# e.g. we passed PeriodIndex.values and got an ndarray of Periods
865-
other = array(other)
866-
other = extract_array(other, extract_numpy=True)
867-
868-
if is_categorical_dtype(other.dtype):
869-
# e.g. we have a Categorical holding self.dtype
870-
if is_dtype_equal(other.categories.dtype, self.dtype):
871-
other = other._internal_get_values()
872-
873-
if not type(self)._is_recognized_dtype(other.dtype):
874-
raise TypeError(f"Where requires matching dtype, not {other.dtype}")
871+
other = self._validate_listlike(other, "where", cast_cat=True)
872+
self._check_compatible_with(other, setitem=True)
875873

876874
self._check_compatible_with(other, setitem=True)
877875
return self._unbox(other)

0 commit comments

Comments
 (0)