Skip to content

REF: de-duplicate listlike validation in DTA._validate_foo #33908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 6, 2020
116 changes: 57 additions & 59 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime, timedelta
import operator
from typing import Any, Sequence, Type, TypeVar, Union, cast
from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast
import warnings

import numpy as np
Expand All @@ -10,7 +10,7 @@
from pandas._libs.tslibs.period import DIFFERENT_FREQ, IncompatibleFrequency, Period
from pandas._libs.tslibs.timedeltas import delta_to_nanoseconds
from pandas._libs.tslibs.timestamps import RoundTo, round_nsint64
from pandas._typing import DatetimeLikeScalar
from pandas._typing import DatetimeLikeScalar, DtypeObj
from pandas.compat import set_function_name
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning
Expand Down Expand Up @@ -86,24 +86,10 @@ def _validate_comparison_value(self, other):
raise ValueError("Lengths must match")

else:
if isinstance(other, list):
# TODO: could use pd.Index to do inference?
other = np.array(other)

if not isinstance(other, (np.ndarray, type(self))):
raise InvalidComparison(other)

elif is_object_dtype(other.dtype):
pass

elif not type(self)._is_recognized_dtype(other.dtype):
raise InvalidComparison(other)

else:
# For PeriodDType this casting is unnecessary
# TODO: use Index to do inference?
other = type(self)._from_sequence(other)
self._check_compatible_with(other)
try:
other = self._validate_listlike(other, opname, allow_object=True)
except TypeError as err:
raise InvalidComparison(other) from err

return other

Expand Down Expand Up @@ -451,6 +437,8 @@ class DatetimeLikeArrayMixin(
_generate_range
"""

_is_recognized_dtype: Callable[[DtypeObj], bool]

# ------------------------------------------------------------------
# NDArrayBackedExtensionArray compat

Expand Down Expand Up @@ -761,6 +749,48 @@ def _validate_shift_value(self, fill_value):

return self._unbox(fill_value)

def _validate_listlike(
self,
value,
opname: str,
cast_str: bool = False,
cast_cat: bool = False,
allow_object: bool = False,
):
if isinstance(value, type(self)):
return value

# Do type inference if necessary up front
# e.g. we passed PeriodIndex.values and got an ndarray of Periods
value = array(value)
value = extract_array(value, extract_numpy=True)

if cast_str and is_dtype_equal(value.dtype, "string"):
# We got a StringArray
try:
# TODO: Could use from_sequence_of_strings if implemented
# Note: passing dtype is necessary for PeriodArray tests
value = type(self)._from_sequence(value, dtype=self.dtype)
except ValueError:
pass

if cast_cat and is_categorical_dtype(value.dtype):
# e.g. we have a Categorical holding self.dtype
if is_dtype_equal(value.categories.dtype, self.dtype):
# TODO: do we need equal dtype or just comparable?
value = value._internal_get_values()

if allow_object and is_object_dtype(value.dtype):
pass

elif not type(self)._is_recognized_dtype(value.dtype):
raise TypeError(
f"{opname} requires compatible dtype or scalar, "
f"not {type(value).__name__}"
)

return value

def _validate_searchsorted_value(self, value):
if isinstance(value, str):
try:
Expand All @@ -776,41 +806,19 @@ def _validate_searchsorted_value(self, value):
elif isinstance(value, self._recognized_scalars):
value = self._scalar_type(value)

elif isinstance(value, type(self)):
pass

elif is_list_like(value) and not isinstance(value, type(self)):
value = array(value)

if not type(self)._is_recognized_dtype(value.dtype):
raise TypeError(
"searchsorted requires compatible dtype or scalar, "
f"not {type(value).__name__}"
)
elif not is_list_like(value):
raise TypeError(f"Unexpected type for 'value': {type(value)}")

else:
raise TypeError(f"Unexpected type for 'value': {type(value)}")
# TODO: cast_str? we accept it for scalar
value = self._validate_listlike(value, "searchsorted")

return self._unbox(value)

def _validate_setitem_value(self, value):

if is_list_like(value):
value = array(value)
if is_dtype_equal(value.dtype, "string"):
# We got a StringArray
try:
# TODO: Could use from_sequence_of_strings if implemented
# Note: passing dtype is necessary for PeriodArray tests
value = type(self)._from_sequence(value, dtype=self.dtype)
except ValueError:
pass

if not type(self)._is_recognized_dtype(value.dtype):
raise TypeError(
"setitem requires compatible dtype or scalar, "
f"not {type(value).__name__}"
)
value = self._validate_listlike(value, "setitem", cast_str=True)

elif isinstance(value, self._recognized_scalars):
value = self._scalar_type(value)
Expand Down Expand Up @@ -851,18 +859,8 @@ def _validate_where_value(self, other):
raise TypeError(f"Where requires matching dtype, not {type(other)}")

else:
# Do type inference if necessary up front
# e.g. we passed PeriodIndex.values and got an ndarray of Periods
other = array(other)
other = extract_array(other, extract_numpy=True)

if is_categorical_dtype(other.dtype):
# e.g. we have a Categorical holding self.dtype
if is_dtype_equal(other.categories.dtype, self.dtype):
other = other._internal_get_values()

if not type(self)._is_recognized_dtype(other.dtype):
raise TypeError(f"Where requires matching dtype, not {other.dtype}")
other = self._validate_listlike(other, "where", cast_cat=True)
self._check_compatible_with(other, setitem=True)

self._check_compatible_with(other, setitem=True)
return self._unbox(other)
Expand Down