1
1
from datetime import datetime , timedelta
2
2
import operator
3
- from typing import Any , Sequence , Type , TypeVar , Union , cast
3
+ from typing import Any , Callable , Sequence , Type , TypeVar , Union , cast
4
4
import warnings
5
5
6
6
import numpy as np
10
10
from pandas ._libs .tslibs .period import DIFFERENT_FREQ , IncompatibleFrequency , Period
11
11
from pandas ._libs .tslibs .timedeltas import delta_to_nanoseconds
12
12
from pandas ._libs .tslibs .timestamps import RoundTo , round_nsint64
13
- from pandas ._typing import DatetimeLikeScalar
13
+ from pandas ._typing import DatetimeLikeScalar , DtypeObj
14
14
from pandas .compat import set_function_name
15
15
from pandas .compat .numpy import function as nv
16
16
from pandas .errors import AbstractMethodError , NullFrequencyError , PerformanceWarning
@@ -86,24 +86,10 @@ def _validate_comparison_value(self, other):
86
86
raise ValueError ("Lengths must match" )
87
87
88
88
else :
89
- if isinstance (other , list ):
90
- # TODO: could use pd.Index to do inference?
91
- other = np .array (other )
92
-
93
- if not isinstance (other , (np .ndarray , type (self ))):
94
- raise InvalidComparison (other )
95
-
96
- elif is_object_dtype (other .dtype ):
97
- pass
98
-
99
- elif not type (self )._is_recognized_dtype (other .dtype ):
100
- raise InvalidComparison (other )
101
-
102
- else :
103
- # For PeriodDType this casting is unnecessary
104
- # TODO: use Index to do inference?
105
- other = type (self )._from_sequence (other )
106
- self ._check_compatible_with (other )
89
+ try :
90
+ other = self ._validate_listlike (other , opname , allow_object = True )
91
+ except TypeError as err :
92
+ raise InvalidComparison (other ) from err
107
93
108
94
return other
109
95
@@ -451,6 +437,8 @@ class DatetimeLikeArrayMixin(
451
437
_generate_range
452
438
"""
453
439
440
+ _is_recognized_dtype : Callable [[DtypeObj ], bool ]
441
+
454
442
# ------------------------------------------------------------------
455
443
# NDArrayBackedExtensionArray compat
456
444
@@ -770,6 +758,48 @@ def _validate_shift_value(self, fill_value):
770
758
771
759
return self ._unbox (fill_value )
772
760
761
+ def _validate_listlike (
762
+ self ,
763
+ value ,
764
+ opname : str ,
765
+ cast_str : bool = False ,
766
+ cast_cat : bool = False ,
767
+ allow_object : bool = False ,
768
+ ):
769
+ if isinstance (value , type (self )):
770
+ return value
771
+
772
+ # Do type inference if necessary up front
773
+ # e.g. we passed PeriodIndex.values and got an ndarray of Periods
774
+ value = array (value )
775
+ value = extract_array (value , extract_numpy = True )
776
+
777
+ if cast_str and is_dtype_equal (value .dtype , "string" ):
778
+ # We got a StringArray
779
+ try :
780
+ # TODO: Could use from_sequence_of_strings if implemented
781
+ # Note: passing dtype is necessary for PeriodArray tests
782
+ value = type (self )._from_sequence (value , dtype = self .dtype )
783
+ except ValueError :
784
+ pass
785
+
786
+ if cast_cat and is_categorical_dtype (value .dtype ):
787
+ # e.g. we have a Categorical holding self.dtype
788
+ if is_dtype_equal (value .categories .dtype , self .dtype ):
789
+ # TODO: do we need equal dtype or just comparable?
790
+ value = value ._internal_get_values ()
791
+
792
+ if allow_object and is_object_dtype (value .dtype ):
793
+ pass
794
+
795
+ elif not type (self )._is_recognized_dtype (value .dtype ):
796
+ raise TypeError (
797
+ f"{ opname } requires compatible dtype or scalar, "
798
+ f"not { type (value ).__name__ } "
799
+ )
800
+
801
+ return value
802
+
773
803
def _validate_searchsorted_value (self , value ):
774
804
if isinstance (value , str ):
775
805
try :
@@ -785,41 +815,19 @@ def _validate_searchsorted_value(self, value):
785
815
elif isinstance (value , self ._recognized_scalars ):
786
816
value = self ._scalar_type (value )
787
817
788
- elif isinstance (value , type (self )):
789
- pass
790
-
791
- elif is_list_like (value ) and not isinstance (value , type (self )):
792
- value = array (value )
793
-
794
- if not type (self )._is_recognized_dtype (value .dtype ):
795
- raise TypeError (
796
- "searchsorted requires compatible dtype or scalar, "
797
- f"not { type (value ).__name__ } "
798
- )
818
+ elif not is_list_like (value ):
819
+ raise TypeError (f"Unexpected type for 'value': { type (value )} " )
799
820
800
821
else :
801
- raise TypeError (f"Unexpected type for 'value': { type (value )} " )
822
+ # TODO: cast_str? we accept it for scalar
823
+ value = self ._validate_listlike (value , "searchsorted" )
802
824
803
825
return self ._unbox (value )
804
826
805
827
def _validate_setitem_value (self , value ):
806
828
807
829
if is_list_like (value ):
808
- value = array (value )
809
- if is_dtype_equal (value .dtype , "string" ):
810
- # We got a StringArray
811
- try :
812
- # TODO: Could use from_sequence_of_strings if implemented
813
- # Note: passing dtype is necessary for PeriodArray tests
814
- value = type (self )._from_sequence (value , dtype = self .dtype )
815
- except ValueError :
816
- pass
817
-
818
- if not type (self )._is_recognized_dtype (value .dtype ):
819
- raise TypeError (
820
- "setitem requires compatible dtype or scalar, "
821
- f"not { type (value ).__name__ } "
822
- )
830
+ value = self ._validate_listlike (value , "setitem" , cast_str = True )
823
831
824
832
elif isinstance (value , self ._recognized_scalars ):
825
833
value = self ._scalar_type (value )
@@ -860,18 +868,8 @@ def _validate_where_value(self, other):
860
868
raise TypeError (f"Where requires matching dtype, not { type (other )} " )
861
869
862
870
else :
863
- # Do type inference if necessary up front
864
- # e.g. we passed PeriodIndex.values and got an ndarray of Periods
865
- other = array (other )
866
- other = extract_array (other , extract_numpy = True )
867
-
868
- if is_categorical_dtype (other .dtype ):
869
- # e.g. we have a Categorical holding self.dtype
870
- if is_dtype_equal (other .categories .dtype , self .dtype ):
871
- other = other ._internal_get_values ()
872
-
873
- if not type (self )._is_recognized_dtype (other .dtype ):
874
- raise TypeError (f"Where requires matching dtype, not { other .dtype } " )
871
+ other = self ._validate_listlike (other , "where" , cast_cat = True )
872
+ self ._check_compatible_with (other , setitem = True )
875
873
876
874
self ._check_compatible_with (other , setitem = True )
877
875
return self ._unbox (other )
0 commit comments