Skip to content

REF: de-duplicate _validate_insert_value with _validate_scalar #37640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _box_func(self, x):
"""
return x

def _validate_insert_value(self, value):
def _validate_scalar(self, value):
# used by NDArrayBackedExtensionIndex.insert
raise AbstractMethodError(self)

Expand Down
5 changes: 2 additions & 3 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,9 +1177,6 @@ def map(self, mapper):
# -------------------------------------------------------------
# Validators; ideally these can be de-duplicated

def _validate_insert_value(self, value) -> int:
return self._validate_fill_value(value)

def _validate_searchsorted_value(self, value):
# searchsorted is very performance sensitive. By converting codes
# to same dtype as self.codes, we get much faster performance.
Expand Down Expand Up @@ -1219,6 +1216,8 @@ def _validate_fill_value(self, fill_value):
)
return fill_value

_validate_scalar = _validate_fill_value

# -------------------------------------------------------------

def __array__(self, dtype=None) -> np.ndarray:
Expand Down
36 changes: 25 additions & 11 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,12 @@ def _validate_fill_value(self, fill_value):
f"Got '{str(fill_value)}'."
)
try:
fill_value = self._validate_scalar(fill_value)
return self._validate_scalar(fill_value)
except TypeError as err:
if "Cannot compare tz-naive and tz-aware" in str(err):
# tzawareness-compat
raise
raise ValueError(msg) from err
return self._unbox(fill_value, setitem=True)

def _validate_shift_value(self, fill_value):
# TODO(2.0): once this deprecation is enforced, use _validate_fill_value
Expand Down Expand Up @@ -511,7 +513,14 @@ def _validate_shift_value(self, fill_value):

return self._unbox(fill_value, setitem=True)

def _validate_scalar(self, value, allow_listlike: bool = False):
def _validate_scalar(
self,
value,
*,
allow_listlike: bool = False,
setitem: bool = True,
unbox: bool = True,
):
"""
Validate that the input value can be cast to our scalar_type.

Expand All @@ -521,6 +530,11 @@ def _validate_scalar(self, value, allow_listlike: bool = False):
allow_listlike: bool, default False
When raising an exception, whether the message should say
listlike inputs are allowed.
setitem : bool, default True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't love the semantics here, e.g. unbox bypassses another option but understand.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe worth have a separate method for the unboxing to avoide this parameter conflation (and share the other logic)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't love the semantics here, e.g. unbox bypassses another option but understand.

yah i didnt like it much either, but in the benchmarks i ran TDI.get_loc was 5x slower if i didnt skip the unboxing.

I've got a couple of ideas kicking around to make this nicer. The only one that has anything on GH ATM is #37605 which still needs discussion.

Whether to check compatibility with setitem strictness.
unbox : bool, default True
Whether to unbox the result before returning. Note: unbox=False
skips the setitem compatibility check.

Returns
-------
Expand All @@ -546,7 +560,12 @@ def _validate_scalar(self, value, allow_listlike: bool = False):
msg = self._validation_error_message(value, allow_listlike)
raise TypeError(msg)

return value
if not unbox:
# NB: In general NDArrayBackedExtensionArray will unbox here;
# this option exists to prevent a performance hit in
# TimedeltaIndex.get_loc
return value
return self._unbox_scalar(value, setitem=setitem)

def _validation_error_message(self, value, allow_listlike: bool = False) -> str:
"""
Expand Down Expand Up @@ -611,7 +630,7 @@ def _validate_listlike(self, value, allow_object: bool = False):

def _validate_searchsorted_value(self, value):
if not is_list_like(value):
value = self._validate_scalar(value, True)
return self._validate_scalar(value, allow_listlike=True, setitem=False)
else:
value = self._validate_listlike(value)

Expand All @@ -621,12 +640,7 @@ def _validate_setitem_value(self, value):
if is_list_like(value):
value = self._validate_listlike(value)
else:
value = self._validate_scalar(value, True)

return self._unbox(value, setitem=True)

def _validate_insert_value(self, value):
value = self._validate_scalar(value)
return self._validate_scalar(value, allow_listlike=True)

return self._unbox(value, setitem=True)

Expand Down
3 changes: 0 additions & 3 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,6 @@ def _validate_fillna_value(self, value):
)
raise TypeError(msg) from err

def _validate_insert_value(self, value):
return self._validate_scalar(value)

def _validate_setitem_value(self, value):
needs_float_conversion = False

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2292,7 +2292,7 @@ def fillna(self, value=None, downcast=None):
DataFrame.fillna : Fill NaN values of a DataFrame.
Series.fillna : Fill NaN Values of a Series.
"""
value = self._validate_scalar(value)
value = self._require_scalar(value)
if self.hasnans:
result = self.putmask(self._isnan, value)
if downcast is None:
Expand Down Expand Up @@ -4140,7 +4140,7 @@ def _validate_fill_value(self, value):
return value

@final
def _validate_scalar(self, value):
def _require_scalar(self, value):
"""
Check that this is a scalar value that we can use for setitem-like
operations without changing dtype.
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def astype(self, dtype, copy=True):

@doc(Index.fillna)
def fillna(self, value, downcast=None):
value = self._validate_scalar(value)
value = self._require_scalar(value)
cat = self._data.fillna(value)
return type(self)._simple_new(cat, name=self.name)

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def _get_insert_freq(self, loc, item):
"""
Find the `freq` for self.insert(loc, item).
"""
value = self._data._validate_insert_value(item)
value = self._data._validate_scalar(item)
item = self._data._box_func(value)

freq = None
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def insert(self, loc: int, item):
ValueError if the item is not valid for this dtype.
"""
arr = self._data
code = arr._validate_insert_value(item)
code = arr._validate_scalar(item)

new_vals = np.concatenate((arr._ndarray[:loc], [code], arr._ndarray[loc:]))
new_arr = arr._from_backing_data(new_vals)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ def insert(self, loc, item):
-------
IntervalIndex
"""
left_insert, right_insert = self._data._validate_insert_value(item)
left_insert, right_insert = self._data._validate_scalar(item)

new_left = self.left.insert(loc, left_insert)
new_right = self.right.insert(loc, right_insert)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def get_loc(self, key, method=None, tolerance=None):
raise InvalidIndexError(key)

try:
key = self._data._validate_scalar(key)
key = self._data._validate_scalar(key, unbox=False)
except TypeError as err:
raise KeyError(key) from err

Expand Down