From 68d08fcd32d33cfee76971ef832300e6cbf1d0b0 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 4 Nov 2020 19:12:38 -0800 Subject: [PATCH] REF: de-duplicate _validate_insert_value with _validate_scalar --- pandas/core/arrays/_mixins.py | 2 +- pandas/core/arrays/categorical.py | 5 ++-- pandas/core/arrays/datetimelike.py | 36 ++++++++++++++++++++--------- pandas/core/arrays/interval.py | 3 --- pandas/core/indexes/base.py | 4 ++-- pandas/core/indexes/category.py | 2 +- pandas/core/indexes/datetimelike.py | 2 +- pandas/core/indexes/extension.py | 2 +- pandas/core/indexes/interval.py | 2 +- pandas/core/indexes/timedeltas.py | 2 +- 10 files changed, 35 insertions(+), 25 deletions(-) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 67ac2a3688214..63c414d96c8de 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -45,7 +45,7 @@ def _box_func(self, x): """ return x - def _validate_insert_value(self, value): + def _validate_scalar(self, value): # used by NDArrayBackedExtensionIndex.insert raise AbstractMethodError(self) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 626fb495dec03..edbf24ca87f5c 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -1177,9 +1177,6 @@ def map(self, mapper): # ------------------------------------------------------------- # Validators; ideally these can be de-duplicated - def _validate_insert_value(self, value) -> int: - return self._validate_fill_value(value) - def _validate_searchsorted_value(self, value): # searchsorted is very performance sensitive. By converting codes # to same dtype as self.codes, we get much faster performance. @@ -1219,6 +1216,8 @@ def _validate_fill_value(self, fill_value): ) return fill_value + _validate_scalar = _validate_fill_value + # ------------------------------------------------------------- def __array__(self, dtype=None) -> np.ndarray: diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 404511895ddf0..7a0d88f29b9b0 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -479,10 +479,12 @@ def _validate_fill_value(self, fill_value): f"Got '{str(fill_value)}'." ) try: - fill_value = self._validate_scalar(fill_value) + return self._validate_scalar(fill_value) except TypeError as err: + if "Cannot compare tz-naive and tz-aware" in str(err): + # tzawareness-compat + raise raise ValueError(msg) from err - return self._unbox(fill_value, setitem=True) def _validate_shift_value(self, fill_value): # TODO(2.0): once this deprecation is enforced, use _validate_fill_value @@ -511,7 +513,14 @@ def _validate_shift_value(self, fill_value): return self._unbox(fill_value, setitem=True) - def _validate_scalar(self, value, allow_listlike: bool = False): + def _validate_scalar( + self, + value, + *, + allow_listlike: bool = False, + setitem: bool = True, + unbox: bool = True, + ): """ Validate that the input value can be cast to our scalar_type. @@ -521,6 +530,11 @@ def _validate_scalar(self, value, allow_listlike: bool = False): allow_listlike: bool, default False When raising an exception, whether the message should say listlike inputs are allowed. + setitem : bool, default True + Whether to check compatibility with setitem strictness. + unbox : bool, default True + Whether to unbox the result before returning. Note: unbox=False + skips the setitem compatibility check. Returns ------- @@ -546,7 +560,12 @@ def _validate_scalar(self, value, allow_listlike: bool = False): msg = self._validation_error_message(value, allow_listlike) raise TypeError(msg) - return value + if not unbox: + # NB: In general NDArrayBackedExtensionArray will unbox here; + # this option exists to prevent a performance hit in + # TimedeltaIndex.get_loc + return value + return self._unbox_scalar(value, setitem=setitem) def _validation_error_message(self, value, allow_listlike: bool = False) -> str: """ @@ -611,7 +630,7 @@ def _validate_listlike(self, value, allow_object: bool = False): def _validate_searchsorted_value(self, value): if not is_list_like(value): - value = self._validate_scalar(value, True) + return self._validate_scalar(value, allow_listlike=True, setitem=False) else: value = self._validate_listlike(value) @@ -621,12 +640,7 @@ def _validate_setitem_value(self, value): if is_list_like(value): value = self._validate_listlike(value) else: - value = self._validate_scalar(value, True) - - return self._unbox(value, setitem=True) - - def _validate_insert_value(self, value): - value = self._validate_scalar(value) + return self._validate_scalar(value, allow_listlike=True) return self._unbox(value, setitem=True) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index f8ece2a9fe7d4..7b10334804ef9 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -889,9 +889,6 @@ def _validate_fillna_value(self, value): ) raise TypeError(msg) from err - def _validate_insert_value(self, value): - return self._validate_scalar(value) - def _validate_setitem_value(self, value): needs_float_conversion = False diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 98ec3b55e65d9..f350e18198057 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2292,7 +2292,7 @@ def fillna(self, value=None, downcast=None): DataFrame.fillna : Fill NaN values of a DataFrame. Series.fillna : Fill NaN Values of a Series. """ - value = self._validate_scalar(value) + value = self._require_scalar(value) if self.hasnans: result = self.putmask(self._isnan, value) if downcast is None: @@ -4140,7 +4140,7 @@ def _validate_fill_value(self, value): return value @final - def _validate_scalar(self, value): + def _require_scalar(self, value): """ Check that this is a scalar value that we can use for setitem-like operations without changing dtype. diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 8cbd0d83c78d7..525c41bae8b51 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -382,7 +382,7 @@ def astype(self, dtype, copy=True): @doc(Index.fillna) def fillna(self, value, downcast=None): - value = self._validate_scalar(value) + value = self._require_scalar(value) cat = self._data.fillna(value) return type(self)._simple_new(cat, name=self.name) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 9e2ac6013cb43..2cb66557b3bab 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -581,7 +581,7 @@ def _get_insert_freq(self, loc, item): """ Find the `freq` for self.insert(loc, item). """ - value = self._data._validate_insert_value(item) + value = self._data._validate_scalar(item) item = self._data._box_func(value) freq = None diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index cd1871e4687f3..921c7aac2c85b 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -335,7 +335,7 @@ def insert(self, loc: int, item): ValueError if the item is not valid for this dtype. """ arr = self._data - code = arr._validate_insert_value(item) + code = arr._validate_scalar(item) new_vals = np.concatenate((arr._ndarray[:loc], [code], arr._ndarray[loc:])) new_arr = arr._from_backing_data(new_vals) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index c700acc24f411..2aec86c9cdfae 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -903,7 +903,7 @@ def insert(self, loc, item): ------- IntervalIndex """ - left_insert, right_insert = self._data._validate_insert_value(item) + left_insert, right_insert = self._data._validate_scalar(item) new_left = self.left.insert(loc, left_insert) new_right = self.right.insert(loc, right_insert) diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index 66fd6943de721..cf5fa4bbb3d75 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -215,7 +215,7 @@ def get_loc(self, key, method=None, tolerance=None): raise InvalidIndexError(key) try: - key = self._data._validate_scalar(key) + key = self._data._validate_scalar(key, unbox=False) except TypeError as err: raise KeyError(key) from err