Skip to content

REF: use array_algos shift for Categorical.shift #33663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pandas/core/array_algos/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
def shift(values: np.ndarray, periods: int, axis: int, fill_value) -> np.ndarray:
new_values = values

if periods == 0:
# TODO: should we copy here?
return new_values
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should, do we elsewhere?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like we currently do a copy in 1/2 of places where this is called


# make sure array sent to np.roll is c_contiguous
f_ordered = values.flags.f_contiguous
if f_ordered:
Expand Down
70 changes: 38 additions & 32 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from pandas.core.dtypes.common import (
ensure_int64,
ensure_object,
ensure_platform_int,
is_categorical_dtype,
is_datetime64_dtype,
is_dict_like,
Expand All @@ -51,6 +50,7 @@
from pandas.core.accessor import PandasDelegate, delegate_names
import pandas.core.algorithms as algorithms
from pandas.core.algorithms import _get_data_algo, factorize, take, take_1d, unique1d
from pandas.core.array_algos.transforms import shift
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs
from pandas.core.base import NoNewAttributesMixin, PandasObject, _shared_docs
import pandas.core.common as com
Expand Down Expand Up @@ -1241,23 +1241,41 @@ def shift(self, periods, fill_value=None):
codes = self.codes
if codes.ndim > 1:
raise NotImplementedError("Categorical with ndim > 1.")
if np.prod(codes.shape) and (periods != 0):
codes = np.roll(codes, ensure_platform_int(periods), axis=0)
if isna(fill_value):
fill_value = -1
elif fill_value in self.categories:
fill_value = self.categories.get_loc(fill_value)
else:
raise ValueError(
f"'fill_value={fill_value}' is not present "
"in this Categorical's categories"
)
if periods > 0:
codes[:periods] = fill_value
else:
codes[periods:] = fill_value

return self.from_codes(codes, dtype=self.dtype)
fill_value = self._validate_fill_value(fill_value)

codes = shift(codes.copy(), periods, axis=0, fill_value=fill_value)

return self._constructor(codes, dtype=self.dtype, fastpath=True)

def _validate_fill_value(self, fill_value):
"""
Convert a user-facing fill_value to a representation to use with our
underlying ndarray, raising ValueError if this is not possible.

Parameters
----------
fill_value : object

Returns
-------
fill_value : int

Raises
------
ValueError
"""

if isna(fill_value):
fill_value = -1
elif fill_value in self.categories:
fill_value = self.categories.get_loc(fill_value)
else:
raise ValueError(
f"'fill_value={fill_value}' is not present "
"in this Categorical's categories"
)
return fill_value

def __array__(self, dtype=None) -> np.ndarray:
"""
Expand Down Expand Up @@ -1835,24 +1853,12 @@ def take(self, indexer, allow_fill: bool = False, fill_value=None):
"""
indexer = np.asarray(indexer, dtype=np.intp)

dtype = self.dtype

if isna(fill_value):
fill_value = -1
elif allow_fill:
if allow_fill:
# convert user-provided `fill_value` to codes
if fill_value in self.categories:
fill_value = self.categories.get_loc(fill_value)
else:
msg = (
f"'fill_value' ('{fill_value}') is not in this "
"Categorical's categories."
)
raise TypeError(msg)
fill_value = self._validate_fill_value(fill_value)

codes = take(self._codes, indexer, allow_fill=allow_fill, fill_value=fill_value)
result = type(self).from_codes(codes, dtype=dtype)
return result
return self._constructor(codes, dtype=self.dtype, fastpath=True)

def take_nd(self, indexer, allow_fill: bool = False, fill_value=None):
# GH#27745 deprecate alias that other EAs dont have
Expand Down
1 change: 1 addition & 0 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,7 @@ def shift(self, periods=1, fill_value=None, axis=0):
if not self.size or periods == 0:
return self.copy()

# TODO(2.0): once this deprecation is enforced, used _validate_fill_value
if is_valid_nat_for_dtype(fill_value, self.dtype):
fill_value = NaT
elif not isinstance(fill_value, self._recognized_scalars):
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/arrays/categorical/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def test_take_fill_value(self):
def test_take_fill_value_new_raises(self):
# https://github.com/pandas-dev/pandas/issues/23296
cat = pd.Categorical(["a", "b", "c"])
xpr = r"'fill_value' \('d'\) is not in this Categorical's categories."
with pytest.raises(TypeError, match=xpr):
xpr = r"'fill_value=d' is not present in this Categorical's categories"
with pytest.raises(ValueError, match=xpr):
cat.take([0, 1, -1], fill_value="d", allow_fill=True)

def test_take_nd_deprecated(self):
Expand Down
6 changes: 3 additions & 3 deletions pandas/tests/frame/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@ def test_unstack_fill_frame_categorical(self):
)
tm.assert_frame_equal(result, expected)

# Fill with non-category results in a TypeError
msg = r"'fill_value' \('d'\) is not in"
with pytest.raises(TypeError, match=msg):
# Fill with non-category results in a ValueError
msg = r"'fill_value=d' is not present in"
with pytest.raises(ValueError, match=msg):
data.unstack(fill_value="d")

# Fill with category value replaces missing values as expected
Expand Down