Skip to content

Commit 5d3fc2d

Browse files
committed
BUG,TST: Remove case where vectorization fails in pct_change groupby method
1 parent a745209 commit 5d3fc2d

File tree

3 files changed

+27
-19
lines changed

3 files changed

+27
-19
lines changed

pandas/core/groupby/generic.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -1222,13 +1222,14 @@ def _apply_to_column_groupbys(self, func):
12221222

12231223
def pct_change(self, periods=1, fill_method='pad', limit=None, freq=None):
12241224
"""Calculate pct_change of each value to previous entry in group"""
1225-
with _group_selection_context(self) as new:
1226-
if fill_method:
1227-
new = copy.copy(new)
1228-
new.obj = getattr(new, fill_method)(limit=limit)
1229-
new._reset_cache('_selected_obj')
1230-
shifted = new.shift(periods=periods, freq=freq)
1231-
return (new.obj / shifted) - 1
1225+
if freq:
1226+
return self.apply(lambda x: x.pct_change(periods=periods,
1227+
fill_method=fill_method,
1228+
limit=limit, freq=freq))
1229+
filled = getattr(self, fill_method)(limit=limit)
1230+
fill_grp = filled.groupby(self.grouper.labels)
1231+
shifted = fill_grp.shift(periods=periods, freq=freq)
1232+
return (filled / shifted) - 1
12321233

12331234

12341235
class DataFrameGroupBy(NDFrameGroupBy):

pandas/core/groupby/groupby.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -1997,13 +1997,11 @@ def pct_change(self, periods=1, fill_method='pad', limit=None, freq=None,
19971997
fill_method=fill_method,
19981998
limit=limit, freq=freq,
19991999
axis=axis))
2000-
with _group_selection_context(self) as new:
2001-
if fill_method:
2002-
new = copy.copy(new)
2003-
new.obj = getattr(new, fill_method)(limit=limit)
2004-
obj = new.obj.drop(self.grouper.names, axis=1)
2005-
shifted = new.shift(periods=periods, freq=freq)
2006-
return (obj / shifted) - 1
2000+
filled = getattr(self, fill_method)(limit=limit)
2001+
filled = filled.drop(self.grouper.names, axis=1)
2002+
fill_grp = filled.groupby(self.grouper.labels)
2003+
shifted = fill_grp.shift(periods=periods, freq=freq)
2004+
return (filled/shifted)-1
20072005

20082006
@Substitution(name='groupby')
20092007
@Appender(_doc_template)

pandas/tests/groupby/test_transform.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -765,20 +765,28 @@ def test_pad_stable_sorting(fill_method):
765765

766766

767767
@pytest.mark.parametrize("test_series", [True, False])
768+
@pytest.mark.parametrize("test_freq", [True, False])
768769
@pytest.mark.parametrize("periods,fill_method,limit", [
769-
(1, None, None), (1, None, 1),
770770
(1, 'ffill', None), (1, 'ffill', 1),
771771
(1, 'bfill', None), (1, 'bfill', 1),
772772
(-1, 'ffill', None), (-1, 'ffill', 1),
773773
(-1, 'bfill', None), (-1, 'bfill', 1),
774774
])
775-
def test_pct_change(test_series, periods, fill_method, limit):
775+
def test_pct_change(test_series, test_freq, periods, fill_method, limit):
776776
# GH 21200, 21621
777777
vals = [3, np.nan, np.nan, np.nan, 1, 2, 4, 10, np.nan, 4]
778778
keys = ['a', 'b']
779779
key_v = np.repeat(keys, len(vals))
780780
df = DataFrame({'key': key_v, 'vals': vals * 2})
781781

782+
if test_freq:
783+
freq = 'D'
784+
dt_idx = pd.DatetimeIndex(start='2010-01-01', freq=freq,
785+
periods=len(vals))
786+
df.index = np.concatenate([dt_idx.values]*2)
787+
else:
788+
freq = None
789+
782790
if fill_method:
783791
df_g = getattr(df.groupby('key'), fill_method)(limit=limit)
784792
grp = df_g.groupby('key')
@@ -790,16 +798,17 @@ def test_pct_change(test_series, periods, fill_method, limit):
790798
if test_series:
791799
result = df.groupby('key')['vals'].pct_change(periods=periods,
792800
fill_method=fill_method,
793-
limit=limit)
801+
limit=limit,
802+
freq=freq)
794803
tm.assert_series_equal(result, expected)
795804
else:
796805
result = df.groupby('key').pct_change(periods=periods,
797806
fill_method=fill_method,
798-
limit=limit)
807+
limit=limit,
808+
freq=freq)
799809
tm.assert_frame_equal(result, expected.to_frame('vals'))
800810

801811

802-
803812
@pytest.mark.parametrize("func", [np.any, np.all])
804813
def test_any_all_np_func(func):
805814
# GH 20653

0 commit comments

Comments
 (0)