Skip to content

REF: de-duplicate _validate_fill_value/_validate_scalar #41790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 2 additions & 21 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def take(
axis: int = 0,
) -> NDArrayBackedExtensionArrayT:
if allow_fill:
fill_value = self._validate_fill_value(fill_value)
fill_value = self._validate_scalar(fill_value)

new_data = take(
self._ndarray,
Expand All @@ -107,25 +107,6 @@ def take(
)
return self._from_backing_data(new_data)

def _validate_fill_value(self, fill_value):
"""
If a fill_value is passed to `take` convert it to a representation
suitable for self._ndarray, raising TypeError if this is not possible.

Parameters
----------
fill_value : object

Returns
-------
fill_value : native representation

Raises
------
TypeError
"""
raise AbstractMethodError(self)

# ------------------------------------------------------------------------

def equals(self, other) -> bool:
Expand Down Expand Up @@ -194,7 +175,7 @@ def shift(self, periods=1, fill_value=None, axis=0):
def _validate_shift_value(self, fill_value):
# TODO: after deprecation in datetimelikearraymixin is enforced,
# we can remove this and ust validate_fill_value directly
return self._validate_fill_value(fill_value)
return self._validate_scalar(fill_value)

def __setitem__(self, key, value):
key = check_array_indexer(self, key)
Expand Down
4 changes: 1 addition & 3 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,7 +1407,7 @@ def _validate_searchsorted_value(self, value):
codes = np.array(locs, dtype=self.codes.dtype) # type: ignore[assignment]
return codes

def _validate_fill_value(self, fill_value):
def _validate_scalar(self, fill_value):
"""
Convert a user-facing fill_value to a representation to use with our
underlying ndarray, raising TypeError if this is not possible.
Expand Down Expand Up @@ -1436,8 +1436,6 @@ def _validate_fill_value(self, fill_value):
)
return fill_value

_validate_scalar = _validate_fill_value

# -------------------------------------------------------------

def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
Expand Down
21 changes: 1 addition & 20 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,27 +557,8 @@ def _validate_comparison_value(self, other):

return other

def _validate_fill_value(self, fill_value):
"""
If a fill_value is passed to `take` convert it to an i8 representation,
raising TypeError if this is not possible.

Parameters
----------
fill_value : object

Returns
-------
fill_value : np.int64, np.datetime64, or np.timedelta64

Raises
------
TypeError
"""
return self._validate_scalar(fill_value)

def _validate_shift_value(self, fill_value):
# TODO(2.0): once this deprecation is enforced, use _validate_fill_value
# TODO(2.0): once this deprecation is enforced, use _validate_scalar
if is_valid_na_for_dtype(fill_value, self.dtype):
fill_value = NaT
elif isinstance(fill_value, self._recognized_scalars):
Expand Down
8 changes: 3 additions & 5 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ def fillna(
if limit is not None:
raise TypeError("limit is not supported for IntervalArray.")

value_left, value_right = self._validate_fill_value(value)
value_left, value_right = self._validate_scalar(value)

left = self.left.fillna(value=value_left)
right = self.right.fillna(value=value_right)
Expand Down Expand Up @@ -1000,7 +1000,7 @@ def take(

fill_left = fill_right = fill_value
if allow_fill:
fill_left, fill_right = self._validate_fill_value(fill_value)
fill_left, fill_right = self._validate_scalar(fill_value)

left_take = take(
self._left, indices, allow_fill=allow_fill, fill_value=fill_left
Expand Down Expand Up @@ -1037,6 +1037,7 @@ def _validate_scalar(self, value):
if isinstance(value, Interval):
self._check_closed_matches(value, name="value")
left, right = value.left, value.right
# TODO: check subdtype match like _validate_setitem_value?
elif is_valid_na_for_dtype(value, self.left.dtype):
# GH#18295
left = right = value
Expand All @@ -1046,9 +1047,6 @@ def _validate_scalar(self, value):
)
return left, right

def _validate_fill_value(self, value):
return self._validate_scalar(value)

def _validate_setitem_value(self, value):
needs_float_conversion = False

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
def isna(self) -> np.ndarray:
return isna(self._ndarray)

def _validate_fill_value(self, fill_value):
def _validate_scalar(self, fill_value):
if fill_value is None:
# Primarily for subclasses
fill_value = self.dtype.na_value
Expand Down