diff --git a/pandas/core/array_algos/transforms.py b/pandas/core/array_algos/transforms.py index f775b6d733d9c..b8b234d937292 100644 --- a/pandas/core/array_algos/transforms.py +++ b/pandas/core/array_algos/transforms.py @@ -10,6 +10,10 @@ def shift(values: np.ndarray, periods: int, axis: int, fill_value) -> np.ndarray: new_values = values + if periods == 0: + # TODO: should we copy here? + return new_values + # make sure array sent to np.roll is c_contiguous f_ordered = values.flags.f_contiguous if f_ordered: diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index af07dee3b6838..c5cac0cfeef6c 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -27,7 +27,6 @@ from pandas.core.dtypes.common import ( ensure_int64, ensure_object, - ensure_platform_int, is_categorical_dtype, is_datetime64_dtype, is_dict_like, @@ -51,6 +50,7 @@ from pandas.core.accessor import PandasDelegate, delegate_names import pandas.core.algorithms as algorithms from pandas.core.algorithms import _get_data_algo, factorize, take, take_1d, unique1d +from pandas.core.array_algos.transforms import shift from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs from pandas.core.base import NoNewAttributesMixin, PandasObject, _shared_docs import pandas.core.common as com @@ -1241,23 +1241,41 @@ def shift(self, periods, fill_value=None): codes = self.codes if codes.ndim > 1: raise NotImplementedError("Categorical with ndim > 1.") - if np.prod(codes.shape) and (periods != 0): - codes = np.roll(codes, ensure_platform_int(periods), axis=0) - if isna(fill_value): - fill_value = -1 - elif fill_value in self.categories: - fill_value = self.categories.get_loc(fill_value) - else: - raise ValueError( - f"'fill_value={fill_value}' is not present " - "in this Categorical's categories" - ) - if periods > 0: - codes[:periods] = fill_value - else: - codes[periods:] = fill_value - return self.from_codes(codes, dtype=self.dtype) + fill_value = self._validate_fill_value(fill_value) + + codes = shift(codes.copy(), periods, axis=0, fill_value=fill_value) + + return self._constructor(codes, dtype=self.dtype, fastpath=True) + + def _validate_fill_value(self, fill_value): + """ + Convert a user-facing fill_value to a representation to use with our + underlying ndarray, raising ValueError if this is not possible. + + Parameters + ---------- + fill_value : object + + Returns + ------- + fill_value : int + + Raises + ------ + ValueError + """ + + if isna(fill_value): + fill_value = -1 + elif fill_value in self.categories: + fill_value = self.categories.get_loc(fill_value) + else: + raise ValueError( + f"'fill_value={fill_value}' is not present " + "in this Categorical's categories" + ) + return fill_value def __array__(self, dtype=None) -> np.ndarray: """ @@ -1835,24 +1853,12 @@ def take(self, indexer, allow_fill: bool = False, fill_value=None): """ indexer = np.asarray(indexer, dtype=np.intp) - dtype = self.dtype - - if isna(fill_value): - fill_value = -1 - elif allow_fill: + if allow_fill: # convert user-provided `fill_value` to codes - if fill_value in self.categories: - fill_value = self.categories.get_loc(fill_value) - else: - msg = ( - f"'fill_value' ('{fill_value}') is not in this " - "Categorical's categories." - ) - raise TypeError(msg) + fill_value = self._validate_fill_value(fill_value) codes = take(self._codes, indexer, allow_fill=allow_fill, fill_value=fill_value) - result = type(self).from_codes(codes, dtype=dtype) - return result + return self._constructor(codes, dtype=self.dtype, fastpath=True) def take_nd(self, indexer, allow_fill: bool = False, fill_value=None): # GH#27745 deprecate alias that other EAs dont have diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index e3cdc898a88bf..430f20b359f8b 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -769,6 +769,7 @@ def shift(self, periods=1, fill_value=None, axis=0): if not self.size or periods == 0: return self.copy() + # TODO(2.0): once this deprecation is enforced, used _validate_fill_value if is_valid_nat_for_dtype(fill_value, self.dtype): fill_value = NaT elif not isinstance(fill_value, self._recognized_scalars): diff --git a/pandas/tests/arrays/categorical/test_algos.py b/pandas/tests/arrays/categorical/test_algos.py index 325fa476d70e6..45e0d503f30e7 100644 --- a/pandas/tests/arrays/categorical/test_algos.py +++ b/pandas/tests/arrays/categorical/test_algos.py @@ -184,8 +184,8 @@ def test_take_fill_value(self): def test_take_fill_value_new_raises(self): # https://github.com/pandas-dev/pandas/issues/23296 cat = pd.Categorical(["a", "b", "c"]) - xpr = r"'fill_value' \('d'\) is not in this Categorical's categories." - with pytest.raises(TypeError, match=xpr): + xpr = r"'fill_value=d' is not present in this Categorical's categories" + with pytest.raises(ValueError, match=xpr): cat.take([0, 1, -1], fill_value="d", allow_fill=True) def test_take_nd_deprecated(self): diff --git a/pandas/tests/frame/test_reshape.py b/pandas/tests/frame/test_reshape.py index 9d3c40ce926d7..2e707342a0793 100644 --- a/pandas/tests/frame/test_reshape.py +++ b/pandas/tests/frame/test_reshape.py @@ -320,9 +320,9 @@ def test_unstack_fill_frame_categorical(self): ) tm.assert_frame_equal(result, expected) - # Fill with non-category results in a TypeError - msg = r"'fill_value' \('d'\) is not in" - with pytest.raises(TypeError, match=msg): + # Fill with non-category results in a ValueError + msg = r"'fill_value=d' is not present in" + with pytest.raises(ValueError, match=msg): data.unstack(fill_value="d") # Fill with category value replaces missing values as expected