Skip to content

Commit 3089006

Browse files
committed
Merge pull request pandas-dev#10497 from bwillers/categorical_shift
BUG: CategoricalBlock shift GH9416
2 parents d7c31ca + 6955de6 commit 3089006

File tree

6 files changed

+83
-0
lines changed

6 files changed

+83
-0
lines changed

doc/source/whatsnew/v0.17.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ Bug Fixes
354354
- Bug in ``io.sql.get_schema`` when specifying multiple columns as primary
355355
key (:issue:`10385`).
356356
- Bug in ``test_categorical`` on big-endian builds (:issue:`10425`)
357+
- Bug in ``Series.shift`` and ``DataFrame.shift`` not supporting categorical data (:issue:`9416`)
357358
- Bug in ``Series.map`` using categorical ``Series`` raises ``AttributeError`` (:issue:`10324`)
358359
- Bug in ``MultiIndex.get_level_values`` including ``Categorical`` raises ``AttributeError`` (:issue:`10460`)
359360

pandas/core/categorical.py

+29
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,35 @@ def shape(self):
820820

821821
return tuple([len(self._codes)])
822822

823+
def shift(self, periods):
824+
"""
825+
Shift Categorical by desired number of periods.
826+
827+
Parameters
828+
----------
829+
periods : int
830+
Number of periods to move, can be positive or negative
831+
832+
Returns
833+
-------
834+
shifted : Categorical
835+
"""
836+
# since categoricals always have ndim == 1, an axis parameter
837+
# doesnt make any sense here.
838+
codes = self.codes
839+
if codes.ndim > 1:
840+
raise NotImplementedError("Categorical with ndim > 1.")
841+
if np.prod(codes.shape) and (periods != 0):
842+
codes = np.roll(codes, com._ensure_platform_int(periods), axis=0)
843+
if periods > 0:
844+
codes[:periods] = -1
845+
else:
846+
codes[periods:] = -1
847+
848+
return Categorical.from_codes(codes,
849+
categories=self.categories,
850+
ordered=self.ordered)
851+
823852
def __array__(self, dtype=None):
824853
"""
825854
The numpy array interface.

pandas/core/internals.py

+4
Original file line numberDiff line numberDiff line change
@@ -1716,6 +1716,10 @@ def interpolate(self, method='pad', axis=0, inplace=False,
17161716
limit=limit),
17171717
placement=self.mgr_locs)
17181718

1719+
def shift(self, periods, axis=0):
1720+
return self.make_block_same_class(values=self.values.shift(periods),
1721+
placement=self.mgr_locs)
1722+
17191723
def take_nd(self, indexer, axis=0, new_mgr_locs=None, fill_tuple=None):
17201724
"""
17211725
Take values according to indexer and return them as a block.bb

pandas/tests/test_categorical.py

+20
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,26 @@ def test_set_item_nan(self):
10801080
exp = np.array([0,1,3,2])
10811081
self.assert_numpy_array_equal(cat.codes, exp)
10821082

1083+
def test_shift(self):
1084+
# GH 9416
1085+
cat = pd.Categorical(['a', 'b', 'c', 'd', 'a'])
1086+
1087+
# shift forward
1088+
sp1 = cat.shift(1)
1089+
xp1 = pd.Categorical([np.nan, 'a', 'b', 'c', 'd'])
1090+
self.assert_categorical_equal(sp1, xp1)
1091+
self.assert_categorical_equal(cat[:-1], sp1[1:])
1092+
1093+
# shift back
1094+
sn2 = cat.shift(-2)
1095+
xp2 = pd.Categorical(['c', 'd', 'a', np.nan, np.nan],
1096+
categories=['a', 'b', 'c', 'd'])
1097+
self.assert_categorical_equal(sn2, xp2)
1098+
self.assert_categorical_equal(cat[2:], sn2[:-2])
1099+
1100+
# shift by zero
1101+
self.assert_categorical_equal(cat, cat.shift(0))
1102+
10831103
def test_nbytes(self):
10841104
cat = pd.Categorical([1,2,3])
10851105
exp = cat._codes.nbytes + cat._categories.values.nbytes

pandas/tests/test_frame.py

+9
Original file line numberDiff line numberDiff line change
@@ -10363,6 +10363,15 @@ def test_shift_bool(self):
1036310363
columns=['high', 'low'])
1036410364
assert_frame_equal(rs, xp)
1036510365

10366+
def test_shift_categorical(self):
10367+
# GH 9416
10368+
s1 = pd.Series(['a', 'b', 'c'], dtype='category')
10369+
s2 = pd.Series(['A', 'B', 'C'], dtype='category')
10370+
df = DataFrame({'one': s1, 'two': s2})
10371+
rs = df.shift(1)
10372+
xp = DataFrame({'one': s1.shift(1), 'two': s2.shift(1)})
10373+
assert_frame_equal(rs, xp)
10374+
1036610375
def test_shift_empty(self):
1036710376
# Regression test for #8019
1036810377
df = DataFrame({'foo': []})

pandas/tests/test_series.py

+20
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from pandas.util.testing import (assert_series_equal,
3737
assert_almost_equal,
3838
assert_frame_equal,
39+
assert_index_equal,
3940
ensure_clean)
4041
import pandas.util.testing as tm
4142

@@ -5255,6 +5256,25 @@ def test_shift_int(self):
52555256
expected = ts.astype(float).shift(1)
52565257
assert_series_equal(shifted, expected)
52575258

5259+
def test_shift_categorical(self):
5260+
# GH 9416
5261+
s = pd.Series(['a', 'b', 'c', 'd'], dtype='category')
5262+
5263+
assert_series_equal(s.iloc[:-1], s.shift(1).shift(-1).valid())
5264+
5265+
sp1 = s.shift(1)
5266+
assert_index_equal(s.index, sp1.index)
5267+
self.assertTrue(np.all(sp1.values.codes[:1] == -1))
5268+
self.assertTrue(np.all(s.values.codes[:-1] == sp1.values.codes[1:]))
5269+
5270+
sn2 = s.shift(-2)
5271+
assert_index_equal(s.index, sn2.index)
5272+
self.assertTrue(np.all(sn2.values.codes[-2:] == -1))
5273+
self.assertTrue(np.all(s.values.codes[2:] == sn2.values.codes[:-2]))
5274+
5275+
assert_index_equal(s.values.categories, sp1.values.categories)
5276+
assert_index_equal(s.values.categories, sn2.values.categories)
5277+
52585278
def test_truncate(self):
52595279
offset = datetools.bday
52605280

0 commit comments

Comments
 (0)