Skip to content

REF: collect validator methods for DTA/TDA/PA #33689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 133 additions & 81 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,9 +588,6 @@ def __setitem__(
# to a period in from_sequence). For DatetimeArray, it's Timestamp...
# I don't know if mypy can do that, possibly with Generics.
# https://mypy.readthedocs.io/en/latest/generics.html
if lib.is_scalar(value) and not isna(value):
value = com.maybe_box_datetimelike(value)

if is_list_like(value):
is_slice = isinstance(key, slice)

Expand All @@ -609,21 +606,7 @@ def __setitem__(
elif not len(key):
return

value = type(self)._from_sequence(value, dtype=self.dtype)
self._check_compatible_with(value, setitem=True)
value = value.asi8
elif isinstance(value, self._scalar_type):
self._check_compatible_with(value, setitem=True)
value = self._unbox_scalar(value)
elif is_valid_nat_for_dtype(value, self.dtype):
value = iNaT
else:
msg = (
f"'value' should be a '{self._scalar_type.__name__}', 'NaT', "
f"or array of those. Got '{type(value).__name__}' instead."
)
raise TypeError(msg)

value = self._validate_setitem_value(value)
key = check_array_indexer(self, key)
self._data[key] = value
self._maybe_clear_freq()
Expand Down Expand Up @@ -682,35 +665,6 @@ def unique(self):
result = unique1d(self.asi8)
return type(self)(result, dtype=self.dtype)

def _validate_fill_value(self, fill_value):
    """
    Convert a ``fill_value`` given to `take` into its i8 representation.

    Parameters
    ----------
    fill_value : object
        NA-like values are mapped to ``iNaT``. Instances of
        ``self._recognized_scalars`` are checked for compatibility,
        cast to ``self._scalar_type`` and then unboxed.

    Returns
    -------
    fill_value : np.int64

    Raises
    ------
    ValueError
        If ``fill_value`` is neither NA-like nor a recognized scalar.
    """
    if isna(fill_value):
        return iNaT

    if not isinstance(fill_value, self._recognized_scalars):
        raise ValueError(
            f"'fill_value' should be a {self._scalar_type}. Got '{fill_value}'."
        )

    # Compatibility is checked on the raw input, before it is cast.
    self._check_compatible_with(fill_value)
    boxed = self._scalar_type(fill_value)
    return self._unbox_scalar(boxed)

def take(self, indices, allow_fill=False, fill_value=None):
if allow_fill:
fill_value = self._validate_fill_value(fill_value)
Expand Down Expand Up @@ -769,6 +723,45 @@ def shift(self, periods=1, fill_value=None, axis=0):
if not self.size or periods == 0:
return self.copy()

fill_value = self._validate_shift_value(fill_value)
new_values = shift(self._data, periods, axis, fill_value)

return type(self)._simple_new(new_values, dtype=self.dtype)

# ------------------------------------------------------------------
# Validation Methods
# TODO: try to de-duplicate these, ensure identical behavior

def _validate_fill_value(self, fill_value):
    """
    Convert a ``fill_value`` given to `take` into its i8 representation.

    Parameters
    ----------
    fill_value : object
        NA-like values are mapped to ``iNaT``. Instances of
        ``self._recognized_scalars`` are checked for compatibility,
        cast to ``self._scalar_type`` and then unboxed.

    Returns
    -------
    fill_value : np.int64

    Raises
    ------
    ValueError
        If ``fill_value`` is neither NA-like nor a recognized scalar.
    """
    if isna(fill_value):
        return iNaT

    if not isinstance(fill_value, self._recognized_scalars):
        raise ValueError(
            f"'fill_value' should be a {self._scalar_type}. Got '{fill_value}'."
        )

    # Compatibility is checked on the raw input, before it is cast.
    self._check_compatible_with(fill_value)
    boxed = self._scalar_type(fill_value)
    return self._unbox_scalar(boxed)

def _validate_shift_value(self, fill_value):
# TODO(2.0): once this deprecation is enforced, use _validate_fill_value
if is_valid_nat_for_dtype(fill_value, self.dtype):
fill_value = NaT
Expand All @@ -787,15 +780,104 @@ def shift(self, periods=1, fill_value=None, axis=0):
"will raise in a future version, pass "
f"{self._scalar_type.__name__} instead.",
FutureWarning,
stacklevel=9,
stacklevel=10,
)
fill_value = new_fill

fill_value = self._unbox_scalar(fill_value)
return fill_value

new_values = shift(self._data, periods, axis, fill_value)
def _validate_searchsorted_value(self, value):
if isinstance(value, str):
try:
value = self._scalar_from_string(value)
except ValueError as err:
raise TypeError(
"searchsorted requires compatible dtype or scalar"
) from err

return type(self)._simple_new(new_values, dtype=self.dtype)
elif is_valid_nat_for_dtype(value, self.dtype):
value = NaT

elif isinstance(value, self._recognized_scalars):
value = self._scalar_type(value)

elif is_list_like(value) and not isinstance(value, type(self)):
value = array(value)

if not type(self)._is_recognized_dtype(value):
raise TypeError(
"searchsorted requires compatible dtype or scalar, "
f"not {type(value).__name__}"
)

if not (isinstance(value, (self._scalar_type, type(self))) or (value is NaT)):
raise TypeError(f"Unexpected type for 'value': {type(value)}")

if isinstance(value, type(self)):
self._check_compatible_with(value)
value = value.asi8
else:
value = self._unbox_scalar(value)

return value

def _validate_setitem_value(self, value):
    """
    Validate a value being assigned through ``__setitem__`` and unbox it.

    Parameters
    ----------
    value : object
        Scalar or list-like to be set. Non-NA scalars are first passed
        through ``com.maybe_box_datetimelike``.

    Returns
    -------
    object
        ``asi8`` of the converted array for list-likes, the unboxed
        scalar for instances of ``self._scalar_type``, or ``iNaT`` for
        a NaT that is valid for ``self.dtype``.

    Raises
    ------
    TypeError
        If ``value`` is none of the accepted kinds.
    """
    if lib.is_scalar(value) and not isna(value):
        value = com.maybe_box_datetimelike(value)

    if is_list_like(value):
        converted = type(self)._from_sequence(value, dtype=self.dtype)
        self._check_compatible_with(converted, setitem=True)
        return converted.asi8

    if isinstance(value, self._scalar_type):
        self._check_compatible_with(value, setitem=True)
        return self._unbox_scalar(value)

    if is_valid_nat_for_dtype(value, self.dtype):
        return iNaT

    raise TypeError(
        f"'value' should be a '{self._scalar_type.__name__}', 'NaT', "
        f"or array of those. Got '{type(value).__name__}' instead."
    )

def _validate_insert_value(self, value):
    """
    Normalize an item that is about to be inserted.

    Parameters
    ----------
    value : object

    Returns
    -------
    object
        ``value`` cast to ``self._scalar_type`` when it is a recognized
        scalar, ``NaT`` when it is a NaT valid for ``self.dtype``,
        otherwise ``value`` unchanged.

    Raises
    ------
    TypeError
        If ``value`` is a scalar NA that is not valid for ``self.dtype``.
    """
    if isinstance(value, self._recognized_scalars):
        return self._scalar_type(value)

    if is_valid_nat_for_dtype(value, self.dtype):
        # GH#18295
        return NaT

    if lib.is_scalar(value) and isna(value):
        raise TypeError(
            f"cannot insert {type(self).__name__} with incompatible label"
        )

    # Anything else is handed back untouched.
    return value

def _validate_where_value(self, other):
    """
    Validate the ``other`` argument of ``where`` and view it as i8.

    Parameters
    ----------
    other : object
        Scalar NA, or something ``Index`` can be constructed from.

    Returns
    -------
    object
        ``NaT.value`` for a scalar NA, otherwise the i8 view of
        ``other`` after conversion to an ``Index``.

    Raises
    ------
    TypeError
        If the inferred dtype of ``other`` does not equal ``self.dtype``.
    """
    if lib.is_scalar(other) and isna(other):
        return NaT.value

    # Do type inference if necessary up front
    # e.g. we passed PeriodIndex.values and got an ndarray of Periods
    from pandas import Index

    other = Index(other)

    # e.g. we have a Categorical holding self.dtype
    if is_categorical_dtype(other) and is_dtype_equal(
        other.categories.dtype, self.dtype
    ):
        other = other._internal_get_values()

    if not is_dtype_equal(self.dtype, other.dtype):
        raise TypeError(f"Where requires matching dtype, not {other.dtype}")

    return other.view("i8")

# ------------------------------------------------------------------
# Additional array methods
Expand Down Expand Up @@ -827,37 +909,7 @@ def searchsorted(self, value, side="left", sorter=None):
indices : array of ints
Array of insertion points with the same shape as `value`.
"""
if isinstance(value, str):
try:
value = self._scalar_from_string(value)
except ValueError as e:
raise TypeError(
"searchsorted requires compatible dtype or scalar"
) from e

elif is_valid_nat_for_dtype(value, self.dtype):
value = NaT

elif isinstance(value, self._recognized_scalars):
value = self._scalar_type(value)

elif is_list_like(value) and not isinstance(value, type(self)):
value = array(value)

if not type(self)._is_recognized_dtype(value):
raise TypeError(
"searchsorted requires compatible dtype or scalar, "
f"not {type(value).__name__}"
)

if not (isinstance(value, (self._scalar_type, type(self))) or (value is NaT)):
raise TypeError(f"Unexpected type for 'value': {type(value)}")

if isinstance(value, type(self)):
self._check_compatible_with(value)
value = value.asi8
else:
value = self._unbox_scalar(value)
value = self._validate_searchsorted_value(value)

# TODO: Use datetime64 semantics for sorting, xref GH#29844
return self.asi8.searchsorted(value, side=side, sorter=sorter)
Expand Down
30 changes: 2 additions & 28 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
ensure_int64,
ensure_platform_int,
is_bool_dtype,
is_categorical_dtype,
is_dtype_equal,
is_integer,
is_list_like,
Expand All @@ -26,7 +25,6 @@
)
from pandas.core.dtypes.concat import concat_compat
from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna

from pandas.core import algorithms
from pandas.core.arrays import DatetimeArray, PeriodArray, TimedeltaArray
Expand Down Expand Up @@ -494,23 +492,7 @@ def isin(self, values, level=None):
def where(self, cond, other=None):
values = self.view("i8")

if is_scalar(other) and isna(other):
other = NaT.value

else:
# Do type inference if necessary up front
# e.g. we passed PeriodIndex.values and got an ndarray of Periods
other = Index(other)

if is_categorical_dtype(other):
# e.g. we have a Categorical holding self.dtype
if is_dtype_equal(other.categories.dtype, self.dtype):
other = other._internal_get_values()

if not is_dtype_equal(self.dtype, other.dtype):
raise TypeError(f"Where requires matching dtype, not {other.dtype}")

other = other.view("i8")
other = self._data._validate_where_value(other)

result = np.where(cond, values, other).astype("i8")
arr = type(self._data)._simple_new(result, dtype=self.dtype)
Expand Down Expand Up @@ -923,15 +905,7 @@ def insert(self, loc, item):
-------
new_index : Index
"""
if isinstance(item, self._data._recognized_scalars):
item = self._data._scalar_type(item)
elif is_valid_nat_for_dtype(item, self.dtype):
# GH 18295
item = self._na_value
elif is_scalar(item) and isna(item):
raise TypeError(
f"cannot insert {type(self).__name__} with incompatible label"
)
item = self._data._validate_insert_value(item)

freq = None
if isinstance(item, self._data._scalar_type) or item is NaT:
Expand Down