Skip to content

Commit 9376960

Browse files
authored
REF: use array_algos shift for Categorical.shift (#33663)
1 parent bf810ff commit 9376960

File tree

5 files changed

+48
-37
lines changed

5 files changed

+48
-37
lines changed

pandas/core/array_algos/transforms.py

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
def shift(values: np.ndarray, periods: int, axis: int, fill_value) -> np.ndarray:
1111
new_values = values
1212

13+
if periods == 0:
14+
# TODO: should we copy here?
15+
return new_values
16+
1317
# make sure array sent to np.roll is c_contiguous
1418
f_ordered = values.flags.f_contiguous
1519
if f_ordered:

pandas/core/arrays/categorical.py

+38-32
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from pandas.core.dtypes.common import (
2828
ensure_int64,
2929
ensure_object,
30-
ensure_platform_int,
3130
is_categorical_dtype,
3231
is_datetime64_dtype,
3332
is_dict_like,
@@ -51,6 +50,7 @@
5150
from pandas.core.accessor import PandasDelegate, delegate_names
5251
import pandas.core.algorithms as algorithms
5352
from pandas.core.algorithms import _get_data_algo, factorize, take, take_1d, unique1d
53+
from pandas.core.array_algos.transforms import shift
5454
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs
5555
from pandas.core.base import NoNewAttributesMixin, PandasObject, _shared_docs
5656
import pandas.core.common as com
@@ -1241,23 +1241,41 @@ def shift(self, periods, fill_value=None):
12411241
codes = self.codes
12421242
if codes.ndim > 1:
12431243
raise NotImplementedError("Categorical with ndim > 1.")
1244-
if np.prod(codes.shape) and (periods != 0):
1245-
codes = np.roll(codes, ensure_platform_int(periods), axis=0)
1246-
if isna(fill_value):
1247-
fill_value = -1
1248-
elif fill_value in self.categories:
1249-
fill_value = self.categories.get_loc(fill_value)
1250-
else:
1251-
raise ValueError(
1252-
f"'fill_value={fill_value}' is not present "
1253-
"in this Categorical's categories"
1254-
)
1255-
if periods > 0:
1256-
codes[:periods] = fill_value
1257-
else:
1258-
codes[periods:] = fill_value
12591244

1260-
return self.from_codes(codes, dtype=self.dtype)
1245+
fill_value = self._validate_fill_value(fill_value)
1246+
1247+
codes = shift(codes.copy(), periods, axis=0, fill_value=fill_value)
1248+
1249+
return self._constructor(codes, dtype=self.dtype, fastpath=True)
1250+
1251+
def _validate_fill_value(self, fill_value):
1252+
"""
1253+
Convert a user-facing fill_value to a representation to use with our
1254+
underlying ndarray, raising ValueError if this is not possible.
1255+
1256+
Parameters
1257+
----------
1258+
fill_value : object
1259+
1260+
Returns
1261+
-------
1262+
fill_value : int
1263+
1264+
Raises
1265+
------
1266+
ValueError
1267+
"""
1268+
1269+
if isna(fill_value):
1270+
fill_value = -1
1271+
elif fill_value in self.categories:
1272+
fill_value = self.categories.get_loc(fill_value)
1273+
else:
1274+
raise ValueError(
1275+
f"'fill_value={fill_value}' is not present "
1276+
"in this Categorical's categories"
1277+
)
1278+
return fill_value
12611279

12621280
def __array__(self, dtype=None) -> np.ndarray:
12631281
"""
@@ -1835,24 +1853,12 @@ def take(self, indexer, allow_fill: bool = False, fill_value=None):
18351853
"""
18361854
indexer = np.asarray(indexer, dtype=np.intp)
18371855

1838-
dtype = self.dtype
1839-
1840-
if isna(fill_value):
1841-
fill_value = -1
1842-
elif allow_fill:
1856+
if allow_fill:
18431857
# convert user-provided `fill_value` to codes
1844-
if fill_value in self.categories:
1845-
fill_value = self.categories.get_loc(fill_value)
1846-
else:
1847-
msg = (
1848-
f"'fill_value' ('{fill_value}') is not in this "
1849-
"Categorical's categories."
1850-
)
1851-
raise TypeError(msg)
1858+
fill_value = self._validate_fill_value(fill_value)
18521859

18531860
codes = take(self._codes, indexer, allow_fill=allow_fill, fill_value=fill_value)
1854-
result = type(self).from_codes(codes, dtype=dtype)
1855-
return result
1861+
return self._constructor(codes, dtype=self.dtype, fastpath=True)
18561862

18571863
def take_nd(self, indexer, allow_fill: bool = False, fill_value=None):
18581864
# GH#27745 deprecate alias that other EAs don't have

pandas/core/arrays/datetimelike.py

+1
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,7 @@ def shift(self, periods=1, fill_value=None, axis=0):
769769
if not self.size or periods == 0:
770770
return self.copy()
771771

772+
# TODO(2.0): once this deprecation is enforced, use _validate_fill_value
772773
if is_valid_nat_for_dtype(fill_value, self.dtype):
773774
fill_value = NaT
774775
elif not isinstance(fill_value, self._recognized_scalars):

pandas/tests/arrays/categorical/test_algos.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ def test_take_fill_value(self):
184184
def test_take_fill_value_new_raises(self):
185185
# https://github.com/pandas-dev/pandas/issues/23296
186186
cat = pd.Categorical(["a", "b", "c"])
187-
xpr = r"'fill_value' \('d'\) is not in this Categorical's categories."
188-
with pytest.raises(TypeError, match=xpr):
187+
xpr = r"'fill_value=d' is not present in this Categorical's categories"
188+
with pytest.raises(ValueError, match=xpr):
189189
cat.take([0, 1, -1], fill_value="d", allow_fill=True)
190190

191191
def test_take_nd_deprecated(self):

pandas/tests/frame/test_reshape.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,9 @@ def test_unstack_fill_frame_categorical(self):
320320
)
321321
tm.assert_frame_equal(result, expected)
322322

323-
# Fill with non-category results in a TypeError
324-
msg = r"'fill_value' \('d'\) is not in"
325-
with pytest.raises(TypeError, match=msg):
323+
# Fill with non-category results in a ValueError
324+
msg = r"'fill_value=d' is not present in"
325+
with pytest.raises(ValueError, match=msg):
326326
data.unstack(fill_value="d")
327327

328328
# Fill with category value replaces missing values as expected

0 commit comments

Comments
 (0)