Skip to content

Commit e70d310

Browse files
authored
PERF: faster groupby diff (#45575)
1 parent 77d9237 commit e70d310

File tree

4 files changed

+70
-3
lines changed

4 files changed

+70
-3
lines changed

asv_bench/benchmarks/groupby.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
method_blocklist = {
2020
"object": {
21+
"diff",
2122
"median",
2223
"prod",
2324
"sem",
@@ -405,7 +406,7 @@ class GroupByMethods:
405406

406407
param_names = ["dtype", "method", "application", "ncols"]
407408
params = [
408-
["int", "float", "object", "datetime", "uint"],
409+
["int", "int16", "float", "object", "datetime", "uint"],
409410
[
410411
"all",
411412
"any",
@@ -417,6 +418,7 @@ class GroupByMethods:
417418
"cumprod",
418419
"cumsum",
419420
"describe",
421+
"diff",
420422
"ffill",
421423
"first",
422424
"head",
@@ -478,7 +480,7 @@ def setup(self, dtype, method, application, ncols):
478480
values = rng.take(taker, axis=0)
479481
if dtype == "int":
480482
key = np.random.randint(0, size, size=size)
481-
elif dtype == "uint":
483+
elif dtype in ("int16", "uint"):
482484
key = np.random.randint(0, size, size=size, dtype=dtype)
483485
elif dtype == "float":
484486
key = np.concatenate(

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ Performance improvements
288288
~~~~~~~~~~~~~~~~~~~~~~~~
289289
- Performance improvement in :meth:`.GroupBy.transform` for some user-defined DataFrame -> Series functions (:issue:`45387`)
290290
- Performance improvement in :meth:`DataFrame.duplicated` when subset consists of only one column (:issue:`45236`)
291+
- Performance improvement in :meth:`.GroupBy.diff` (:issue:`16706`)
291292
- Performance improvement in :meth:`.GroupBy.transform` when broadcasting values for user-defined functions (:issue:`45708`)
292293
- Performance improvement in :meth:`.GroupBy.transform` for user-defined functions when only a single group exists (:issue:`44977`)
293294
- Performance improvement in :meth:`MultiIndex.get_locs` (:issue:`45681`, :issue:`46040`)

pandas/core/groupby/groupby.py

+41
Original file line numberDiff line numberDiff line change
@@ -3456,6 +3456,47 @@ def shift(self, periods=1, freq=None, axis=0, fill_value=None):
34563456
)
34573457
return res
34583458

3459+
@final
3460+
@Substitution(name="groupby")
3461+
@Appender(_common_see_also)
3462+
def diff(self, periods: int = 1, axis: int = 0) -> Series | DataFrame:
3463+
"""
3464+
First discrete difference of element.
3465+
3466+
Calculates the difference of each element compared with another
3467+
element in the group (default is element in previous row).
3468+
3469+
Parameters
3470+
----------
3471+
periods : int, default 1
3472+
Periods to shift for calculating difference, accepts negative values.
3473+
axis : axis to shift, default 0
3474+
Take difference over rows (0) or columns (1).
3475+
3476+
Returns
3477+
-------
3478+
Series or DataFrame
3479+
First differences.
3480+
"""
3481+
if axis != 0:
3482+
return self.apply(lambda x: x.diff(periods=periods, axis=axis))
3483+
3484+
obj = self._obj_with_exclusions
3485+
shifted = self.shift(periods=periods, axis=axis)
3486+
3487+
# GH45562 - to retain existing behavior and match behavior of Series.diff(),
3488+
# int8 and int16 are coerced to float32 rather than float64.
3489+
dtypes_to_f32 = ["int8", "int16"]
3490+
if obj.ndim == 1:
3491+
if obj.dtype in dtypes_to_f32:
3492+
shifted = shifted.astype("float32")
3493+
else:
3494+
to_coerce = [c for c, dtype in obj.dtypes.items() if dtype in dtypes_to_f32]
3495+
if len(to_coerce):
3496+
shifted = shifted.astype({c: "float32" for c in to_coerce})
3497+
3498+
return obj - shifted
3499+
34593500
@final
34603501
@Substitution(name="groupby")
34613502
@Appender(_common_see_also)

pandas/tests/groupby/test_groupby_shift_diff.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_group_shift_lose_timezone():
6969
tm.assert_series_equal(result, expected)
7070

7171

72-
def test_group_diff_real(any_real_numpy_dtype):
72+
def test_group_diff_real_series(any_real_numpy_dtype):
7373
df = DataFrame(
7474
{"a": [1, 2, 3, 3, 2], "b": [1, 2, 3, 4, 5]},
7575
dtype=any_real_numpy_dtype,
@@ -82,6 +82,29 @@ def test_group_diff_real(any_real_numpy_dtype):
8282
tm.assert_series_equal(result, expected)
8383

8484

85+
def test_group_diff_real_frame(any_real_numpy_dtype):
86+
df = DataFrame(
87+
{
88+
"a": [1, 2, 3, 3, 2],
89+
"b": [1, 2, 3, 4, 5],
90+
"c": [1, 2, 3, 4, 6],
91+
},
92+
dtype=any_real_numpy_dtype,
93+
)
94+
result = df.groupby("a").diff()
95+
exp_dtype = "float"
96+
if any_real_numpy_dtype in ["int8", "int16", "float32"]:
97+
exp_dtype = "float32"
98+
expected = DataFrame(
99+
{
100+
"b": [np.nan, np.nan, np.nan, 1.0, 3.0],
101+
"c": [np.nan, np.nan, np.nan, 1.0, 4.0],
102+
},
103+
dtype=exp_dtype,
104+
)
105+
tm.assert_frame_equal(result, expected)
106+
107+
85108
@pytest.mark.parametrize(
86109
"data",
87110
[

0 commit comments

Comments
 (0)