Skip to content

Commit cce0b85

Browse files
jbrockmendelrhshadrach
authored andcommitted
REF: collect validator methods (pandas-dev#33689)
1 parent 29276d6 commit cce0b85

File tree

2 files changed

+135
-109
lines changed

2 files changed

+135
-109
lines changed

pandas/core/arrays/datetimelike.py

+133-81
Original file line numberDiff line numberDiff line change
@@ -588,9 +588,6 @@ def __setitem__(
588588
# to a period in from_sequence). For DatetimeArray, it's Timestamp...
589589
# I don't know if mypy can do that, possibly with Generics.
590590
# https://mypy.readthedocs.io/en/latest/generics.html
591-
if lib.is_scalar(value) and not isna(value):
592-
value = com.maybe_box_datetimelike(value)
593-
594591
if is_list_like(value):
595592
is_slice = isinstance(key, slice)
596593

@@ -609,21 +606,7 @@ def __setitem__(
609606
elif not len(key):
610607
return
611608

612-
value = type(self)._from_sequence(value, dtype=self.dtype)
613-
self._check_compatible_with(value, setitem=True)
614-
value = value.asi8
615-
elif isinstance(value, self._scalar_type):
616-
self._check_compatible_with(value, setitem=True)
617-
value = self._unbox_scalar(value)
618-
elif is_valid_nat_for_dtype(value, self.dtype):
619-
value = iNaT
620-
else:
621-
msg = (
622-
f"'value' should be a '{self._scalar_type.__name__}', 'NaT', "
623-
f"or array of those. Got '{type(value).__name__}' instead."
624-
)
625-
raise TypeError(msg)
626-
609+
value = self._validate_setitem_value(value)
627610
key = check_array_indexer(self, key)
628611
self._data[key] = value
629612
self._maybe_clear_freq()
@@ -682,35 +665,6 @@ def unique(self):
682665
result = unique1d(self.asi8)
683666
return type(self)(result, dtype=self.dtype)
684667

685-
def _validate_fill_value(self, fill_value):
686-
"""
687-
If a fill_value is passed to `take` convert it to an i8 representation,
688-
raising ValueError if this is not possible.
689-
690-
Parameters
691-
----------
692-
fill_value : object
693-
694-
Returns
695-
-------
696-
fill_value : np.int64
697-
698-
Raises
699-
------
700-
ValueError
701-
"""
702-
if isna(fill_value):
703-
fill_value = iNaT
704-
elif isinstance(fill_value, self._recognized_scalars):
705-
self._check_compatible_with(fill_value)
706-
fill_value = self._scalar_type(fill_value)
707-
fill_value = self._unbox_scalar(fill_value)
708-
else:
709-
raise ValueError(
710-
f"'fill_value' should be a {self._scalar_type}. Got '{fill_value}'."
711-
)
712-
return fill_value
713-
714668
def take(self, indices, allow_fill=False, fill_value=None):
715669
if allow_fill:
716670
fill_value = self._validate_fill_value(fill_value)
@@ -769,6 +723,45 @@ def shift(self, periods=1, fill_value=None, axis=0):
769723
if not self.size or periods == 0:
770724
return self.copy()
771725

726+
fill_value = self._validate_shift_value(fill_value)
727+
new_values = shift(self._data, periods, axis, fill_value)
728+
729+
return type(self)._simple_new(new_values, dtype=self.dtype)
730+
731+
# ------------------------------------------------------------------
732+
# Validation Methods
733+
# TODO: try to de-duplicate these, ensure identical behavior
734+
735+
def _validate_fill_value(self, fill_value):
736+
"""
737+
If a fill_value is passed to `take` convert it to an i8 representation,
738+
raising ValueError if this is not possible.
739+
740+
Parameters
741+
----------
742+
fill_value : object
743+
744+
Returns
745+
-------
746+
fill_value : np.int64
747+
748+
Raises
749+
------
750+
ValueError
751+
"""
752+
if isna(fill_value):
753+
fill_value = iNaT
754+
elif isinstance(fill_value, self._recognized_scalars):
755+
self._check_compatible_with(fill_value)
756+
fill_value = self._scalar_type(fill_value)
757+
fill_value = self._unbox_scalar(fill_value)
758+
else:
759+
raise ValueError(
760+
f"'fill_value' should be a {self._scalar_type}. Got '{fill_value}'."
761+
)
762+
return fill_value
763+
764+
def _validate_shift_value(self, fill_value):
772765
# TODO(2.0): once this deprecation is enforced, used _validate_fill_value
773766
if is_valid_nat_for_dtype(fill_value, self.dtype):
774767
fill_value = NaT
@@ -787,15 +780,104 @@ def shift(self, periods=1, fill_value=None, axis=0):
787780
"will raise in a future version, pass "
788781
f"{self._scalar_type.__name__} instead.",
789782
FutureWarning,
790-
stacklevel=9,
783+
stacklevel=10,
791784
)
792785
fill_value = new_fill
793786

794787
fill_value = self._unbox_scalar(fill_value)
788+
return fill_value
795789

796-
new_values = shift(self._data, periods, axis, fill_value)
790+
def _validate_searchsorted_value(self, value):
791+
if isinstance(value, str):
792+
try:
793+
value = self._scalar_from_string(value)
794+
except ValueError as err:
795+
raise TypeError(
796+
"searchsorted requires compatible dtype or scalar"
797+
) from err
797798

798-
return type(self)._simple_new(new_values, dtype=self.dtype)
799+
elif is_valid_nat_for_dtype(value, self.dtype):
800+
value = NaT
801+
802+
elif isinstance(value, self._recognized_scalars):
803+
value = self._scalar_type(value)
804+
805+
elif is_list_like(value) and not isinstance(value, type(self)):
806+
value = array(value)
807+
808+
if not type(self)._is_recognized_dtype(value):
809+
raise TypeError(
810+
"searchsorted requires compatible dtype or scalar, "
811+
f"not {type(value).__name__}"
812+
)
813+
814+
if not (isinstance(value, (self._scalar_type, type(self))) or (value is NaT)):
815+
raise TypeError(f"Unexpected type for 'value': {type(value)}")
816+
817+
if isinstance(value, type(self)):
818+
self._check_compatible_with(value)
819+
value = value.asi8
820+
else:
821+
value = self._unbox_scalar(value)
822+
823+
return value
824+
825+
def _validate_setitem_value(self, value):
826+
if lib.is_scalar(value) and not isna(value):
827+
value = com.maybe_box_datetimelike(value)
828+
829+
if is_list_like(value):
830+
value = type(self)._from_sequence(value, dtype=self.dtype)
831+
self._check_compatible_with(value, setitem=True)
832+
value = value.asi8
833+
elif isinstance(value, self._scalar_type):
834+
self._check_compatible_with(value, setitem=True)
835+
value = self._unbox_scalar(value)
836+
elif is_valid_nat_for_dtype(value, self.dtype):
837+
value = iNaT
838+
else:
839+
msg = (
840+
f"'value' should be a '{self._scalar_type.__name__}', 'NaT', "
841+
f"or array of those. Got '{type(value).__name__}' instead."
842+
)
843+
raise TypeError(msg)
844+
845+
return value
846+
847+
def _validate_insert_value(self, value):
848+
if isinstance(value, self._recognized_scalars):
849+
value = self._scalar_type(value)
850+
elif is_valid_nat_for_dtype(value, self.dtype):
851+
# GH#18295
852+
value = NaT
853+
elif lib.is_scalar(value) and isna(value):
854+
raise TypeError(
855+
f"cannot insert {type(self).__name__} with incompatible label"
856+
)
857+
858+
return value
859+
860+
def _validate_where_value(self, other):
861+
if lib.is_scalar(other) and isna(other):
862+
other = NaT.value
863+
864+
else:
865+
# Do type inference if necessary up front
866+
# e.g. we passed PeriodIndex.values and got an ndarray of Periods
867+
from pandas import Index
868+
869+
other = Index(other)
870+
871+
if is_categorical_dtype(other):
872+
# e.g. we have a Categorical holding self.dtype
873+
if is_dtype_equal(other.categories.dtype, self.dtype):
874+
other = other._internal_get_values()
875+
876+
if not is_dtype_equal(self.dtype, other.dtype):
877+
raise TypeError(f"Where requires matching dtype, not {other.dtype}")
878+
879+
other = other.view("i8")
880+
return other
799881

800882
# ------------------------------------------------------------------
801883
# Additional array methods
@@ -827,37 +909,7 @@ def searchsorted(self, value, side="left", sorter=None):
827909
indices : array of ints
828910
Array of insertion points with the same shape as `value`.
829911
"""
830-
if isinstance(value, str):
831-
try:
832-
value = self._scalar_from_string(value)
833-
except ValueError as e:
834-
raise TypeError(
835-
"searchsorted requires compatible dtype or scalar"
836-
) from e
837-
838-
elif is_valid_nat_for_dtype(value, self.dtype):
839-
value = NaT
840-
841-
elif isinstance(value, self._recognized_scalars):
842-
value = self._scalar_type(value)
843-
844-
elif is_list_like(value) and not isinstance(value, type(self)):
845-
value = array(value)
846-
847-
if not type(self)._is_recognized_dtype(value):
848-
raise TypeError(
849-
"searchsorted requires compatible dtype or scalar, "
850-
f"not {type(value).__name__}"
851-
)
852-
853-
if not (isinstance(value, (self._scalar_type, type(self))) or (value is NaT)):
854-
raise TypeError(f"Unexpected type for 'value': {type(value)}")
855-
856-
if isinstance(value, type(self)):
857-
self._check_compatible_with(value)
858-
value = value.asi8
859-
else:
860-
value = self._unbox_scalar(value)
912+
value = self._validate_searchsorted_value(value)
861913

862914
# TODO: Use datetime64 semantics for sorting, xref GH#29844
863915
return self.asi8.searchsorted(value, side=side, sorter=sorter)

pandas/core/indexes/datetimelike.py

+2-28
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
ensure_int64,
1818
ensure_platform_int,
1919
is_bool_dtype,
20-
is_categorical_dtype,
2120
is_dtype_equal,
2221
is_integer,
2322
is_list_like,
@@ -26,7 +25,6 @@
2625
)
2726
from pandas.core.dtypes.concat import concat_compat
2827
from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries
29-
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna
3028

3129
from pandas.core import algorithms
3230
from pandas.core.arrays import DatetimeArray, PeriodArray, TimedeltaArray
@@ -494,23 +492,7 @@ def isin(self, values, level=None):
494492
def where(self, cond, other=None):
495493
values = self.view("i8")
496494

497-
if is_scalar(other) and isna(other):
498-
other = NaT.value
499-
500-
else:
501-
# Do type inference if necessary up front
502-
# e.g. we passed PeriodIndex.values and got an ndarray of Periods
503-
other = Index(other)
504-
505-
if is_categorical_dtype(other):
506-
# e.g. we have a Categorical holding self.dtype
507-
if is_dtype_equal(other.categories.dtype, self.dtype):
508-
other = other._internal_get_values()
509-
510-
if not is_dtype_equal(self.dtype, other.dtype):
511-
raise TypeError(f"Where requires matching dtype, not {other.dtype}")
512-
513-
other = other.view("i8")
495+
other = self._data._validate_where_value(other)
514496

515497
result = np.where(cond, values, other).astype("i8")
516498
arr = type(self._data)._simple_new(result, dtype=self.dtype)
@@ -923,15 +905,7 @@ def insert(self, loc, item):
923905
-------
924906
new_index : Index
925907
"""
926-
if isinstance(item, self._data._recognized_scalars):
927-
item = self._data._scalar_type(item)
928-
elif is_valid_nat_for_dtype(item, self.dtype):
929-
# GH 18295
930-
item = self._na_value
931-
elif is_scalar(item) and isna(item):
932-
raise TypeError(
933-
f"cannot insert {type(self).__name__} with incompatible label"
934-
)
908+
item = self._data._validate_insert_value(item)
935909

936910
freq = None
937911
if isinstance(item, self._data._scalar_type) or item is NaT:

0 commit comments

Comments
 (0)