Skip to content

Commit acd914d

Browse files
Backport PR #57102 on branch 2.2.x (ENH: Add skipna to groupby.first and groupby.last) (#57141)
Backport PR #57102: ENH: Add skipna to groupby.first and groupby.last Co-authored-by: Richard Shadrach <[email protected]>
1 parent 10b5873 commit acd914d

File tree

10 files changed

+167
-28
lines changed

10 files changed

+167
-28
lines changed

doc/source/whatsnew/v2.2.1.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ Bug fixes
3434

3535
Other
3636
~~~~~
37-
-
37+
- Added the argument ``skipna`` to :meth:`DataFrameGroupBy.first`, :meth:`DataFrameGroupBy.last`, :meth:`SeriesGroupBy.first`, and :meth:`SeriesGroupBy.last`; achieving ``skipna=False`` used to be available via :meth:`DataFrameGroupBy.nth`, but the behavior was changed in pandas 2.0.0 (:issue:`57019`)
38+
- Added the argument ``skipna`` to :meth:`Resampler.first` and :meth:`Resampler.last` (:issue:`57019`)
3839

3940
.. ---------------------------------------------------------------------------
4041
.. _whatsnew_221.contributors:

pandas/_libs/groupby.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def group_last(
136136
result_mask: npt.NDArray[np.bool_] | None = ...,
137137
min_count: int = ..., # Py_ssize_t
138138
is_datetimelike: bool = ...,
139+
skipna: bool = ...,
139140
) -> None: ...
140141
def group_nth(
141142
out: np.ndarray, # rank_t[:, ::1]
@@ -147,6 +148,7 @@ def group_nth(
147148
min_count: int = ..., # int64_t
148149
rank: int = ..., # int64_t
149150
is_datetimelike: bool = ...,
151+
skipna: bool = ...,
150152
) -> None: ...
151153
def group_rank(
152154
out: np.ndarray, # float64_t[:, ::1]

pandas/_libs/groupby.pyx

+26-15
Original file line numberDiff line numberDiff line change
@@ -1424,6 +1424,7 @@ def group_last(
14241424
uint8_t[:, ::1] result_mask=None,
14251425
Py_ssize_t min_count=-1,
14261426
bint is_datetimelike=False,
1427+
bint skipna=True,
14271428
) -> None:
14281429
"""
14291430
Only aggregates on axis=0
@@ -1458,14 +1459,19 @@ def group_last(
14581459
for j in range(K):
14591460
val = values[i, j]
14601461

1461-
if uses_mask:
1462-
isna_entry = mask[i, j]
1463-
else:
1464-
isna_entry = _treat_as_na(val, is_datetimelike)
1462+
if skipna:
1463+
if uses_mask:
1464+
isna_entry = mask[i, j]
1465+
else:
1466+
isna_entry = _treat_as_na(val, is_datetimelike)
1467+
if isna_entry:
1468+
continue
14651469

1466-
if not isna_entry:
1467-
nobs[lab, j] += 1
1468-
resx[lab, j] = val
1470+
nobs[lab, j] += 1
1471+
resx[lab, j] = val
1472+
1473+
if uses_mask and not skipna:
1474+
result_mask[lab, j] = mask[i, j]
14691475

14701476
_check_below_mincount(
14711477
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx
@@ -1486,6 +1492,7 @@ def group_nth(
14861492
int64_t min_count=-1,
14871493
int64_t rank=1,
14881494
bint is_datetimelike=False,
1495+
bint skipna=True,
14891496
) -> None:
14901497
"""
14911498
Only aggregates on axis=0
@@ -1520,15 +1527,19 @@ def group_nth(
15201527
for j in range(K):
15211528
val = values[i, j]
15221529

1523-
if uses_mask:
1524-
isna_entry = mask[i, j]
1525-
else:
1526-
isna_entry = _treat_as_na(val, is_datetimelike)
1530+
if skipna:
1531+
if uses_mask:
1532+
isna_entry = mask[i, j]
1533+
else:
1534+
isna_entry = _treat_as_na(val, is_datetimelike)
1535+
if isna_entry:
1536+
continue
15271537

1528-
if not isna_entry:
1529-
nobs[lab, j] += 1
1530-
if nobs[lab, j] == rank:
1531-
resx[lab, j] = val
1538+
nobs[lab, j] += 1
1539+
if nobs[lab, j] == rank:
1540+
resx[lab, j] = val
1541+
if uses_mask and not skipna:
1542+
result_mask[lab, j] = mask[i, j]
15321543

15331544
_check_below_mincount(
15341545
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx

pandas/_testing/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,18 @@
236236
+ TIMEDELTA_PYARROW_DTYPES
237237
+ BOOL_PYARROW_DTYPES
238238
)
239+
ALL_REAL_PYARROW_DTYPES_STR_REPR = (
240+
ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR
241+
)
239242
else:
240243
FLOAT_PYARROW_DTYPES_STR_REPR = []
241244
ALL_INT_PYARROW_DTYPES_STR_REPR = []
242245
ALL_PYARROW_DTYPES = []
246+
ALL_REAL_PYARROW_DTYPES_STR_REPR = []
243247

248+
ALL_REAL_NULLABLE_DTYPES = (
249+
FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR
250+
)
244251

245252
arithmetic_dunder_methods = [
246253
"__add__",

pandas/conftest.py

+32
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,38 @@ def any_numpy_dtype(request):
16421642
return request.param
16431643

16441644

1645+
@pytest.fixture(params=tm.ALL_REAL_NULLABLE_DTYPES)
1646+
def any_real_nullable_dtype(request):
1647+
"""
1648+
Parameterized fixture for all real dtypes that can hold NA.
1649+
1650+
* float
1651+
* 'float32'
1652+
* 'float64'
1653+
* 'Float32'
1654+
* 'Float64'
1655+
* 'UInt8'
1656+
* 'UInt16'
1657+
* 'UInt32'
1658+
* 'UInt64'
1659+
* 'Int8'
1660+
* 'Int16'
1661+
* 'Int32'
1662+
* 'Int64'
1663+
* 'uint8[pyarrow]'
1664+
* 'uint16[pyarrow]'
1665+
* 'uint32[pyarrow]'
1666+
* 'uint64[pyarrow]'
1667+
* 'int8[pyarrow]'
1668+
* 'int16[pyarrow]'
1669+
* 'int32[pyarrow]'
1670+
* 'int64[pyarrow]'
1671+
* 'float[pyarrow]'
1672+
* 'double[pyarrow]'
1673+
"""
1674+
return request.param
1675+
1676+
16451677
@pytest.fixture(params=tm.ALL_NUMERIC_DTYPES)
16461678
def any_numeric_dtype(request):
16471679
"""

pandas/core/groupby/groupby.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -3335,22 +3335,31 @@ def max(
33353335
)
33363336

33373337
@final
3338-
def first(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
3338+
def first(
3339+
self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True
3340+
) -> NDFrameT:
33393341
"""
3340-
Compute the first non-null entry of each column.
3342+
Compute the first entry of each column within each group.
3343+
3344+
Defaults to skipping NA elements.
33413345
33423346
Parameters
33433347
----------
33443348
numeric_only : bool, default False
33453349
Include only float, int, boolean columns.
33463350
min_count : int, default -1
33473351
The required number of valid values to perform the operation. If fewer
3348-
than ``min_count`` non-NA values are present the result will be NA.
3352+
than ``min_count`` valid values are present the result will be NA.
3353+
skipna : bool, default True
3354+
Exclude NA/null values. If all values in a group are NA, the result
3355+
will be NA.
3356+
3357+
.. versionadded:: 2.2.1
33493358
33503359
Returns
33513360
-------
33523361
Series or DataFrame
3353-
First non-null of values within each group.
3362+
First values within each group.
33543363
33553364
See Also
33563365
--------
@@ -3402,12 +3411,17 @@ def first(x: Series):
34023411
min_count=min_count,
34033412
alias="first",
34043413
npfunc=first_compat,
3414+
skipna=skipna,
34053415
)
34063416

34073417
@final
3408-
def last(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
3418+
def last(
3419+
self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True
3420+
) -> NDFrameT:
34093421
"""
3410-
Compute the last non-null entry of each column.
3422+
Compute the last entry of each column within each group.
3423+
3424+
Defaults to skipping NA elements.
34113425
34123426
Parameters
34133427
----------
@@ -3416,12 +3430,17 @@ def last(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
34163430
everything, then use only numeric data.
34173431
min_count : int, default -1
34183432
The required number of valid values to perform the operation. If fewer
3419-
than ``min_count`` non-NA values are present the result will be NA.
3433+
than ``min_count`` valid values are present the result will be NA.
3434+
skipna : bool, default True
3435+
Exclude NA/null values. If all values in a group are NA, the result
3436+
will be NA.
3437+
3438+
.. versionadded:: 2.2.1
34203439
34213440
Returns
34223441
-------
34233442
Series or DataFrame
3424-
Last non-null of values within each group.
3443+
Last values within each group.
34253444
34263445
See Also
34273446
--------
@@ -3461,6 +3480,7 @@ def last(x: Series):
34613480
min_count=min_count,
34623481
alias="last",
34633482
npfunc=last_compat,
3483+
skipna=skipna,
34643484
)
34653485

34663486
@final

pandas/core/resample.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1306,25 +1306,31 @@ def first(
13061306
self,
13071307
numeric_only: bool = False,
13081308
min_count: int = 0,
1309+
skipna: bool = True,
13091310
*args,
13101311
**kwargs,
13111312
):
13121313
maybe_warn_args_and_kwargs(type(self), "first", args, kwargs)
13131314
nv.validate_resampler_func("first", args, kwargs)
1314-
return self._downsample("first", numeric_only=numeric_only, min_count=min_count)
1315+
return self._downsample(
1316+
"first", numeric_only=numeric_only, min_count=min_count, skipna=skipna
1317+
)
13151318

13161319
@final
13171320
@doc(GroupBy.last)
13181321
def last(
13191322
self,
13201323
numeric_only: bool = False,
13211324
min_count: int = 0,
1325+
skipna: bool = True,
13221326
*args,
13231327
**kwargs,
13241328
):
13251329
maybe_warn_args_and_kwargs(type(self), "last", args, kwargs)
13261330
nv.validate_resampler_func("last", args, kwargs)
1327-
return self._downsample("last", numeric_only=numeric_only, min_count=min_count)
1331+
return self._downsample(
1332+
"last", numeric_only=numeric_only, min_count=min_count, skipna=skipna
1333+
)
13281334

13291335
@final
13301336
@doc(GroupBy.median)

pandas/tests/groupby/test_reductions.py

+31
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
from pandas._libs.tslibs import iNaT
99

10+
from pandas.core.dtypes.common import pandas_dtype
11+
from pandas.core.dtypes.missing import na_value_for_dtype
12+
1013
import pandas as pd
1114
from pandas import (
1215
DataFrame,
@@ -327,6 +330,34 @@ def test_groupby_non_arithmetic_agg_int_like_precision(method, data):
327330
tm.assert_frame_equal(result, expected)
328331

329332

333+
@pytest.mark.parametrize("how", ["first", "last"])
334+
def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how):
335+
# GH#57019
336+
na_value = na_value_for_dtype(pandas_dtype(any_real_nullable_dtype))
337+
df = DataFrame(
338+
{
339+
"a": [2, 1, 1, 2, 3, 3],
340+
"b": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
341+
"c": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
342+
},
343+
dtype=any_real_nullable_dtype,
344+
)
345+
gb = df.groupby("a", sort=sort)
346+
method = getattr(gb, how)
347+
result = method(skipna=skipna)
348+
349+
ilocs = {
350+
("first", True): [3, 1, 4],
351+
("first", False): [0, 1, 4],
352+
("last", True): [3, 1, 5],
353+
("last", False): [3, 2, 5],
354+
}[how, skipna]
355+
expected = df.iloc[ilocs].set_index("a")
356+
if sort:
357+
expected = expected.sort_index()
358+
tm.assert_frame_equal(result, expected)
359+
360+
330361
def test_idxmin_idxmax_axis1():
331362
df = DataFrame(
332363
np.random.default_rng(2).standard_normal((10, 4)), columns=["A", "B", "C", "D"]

pandas/tests/resample/test_base.py

+29
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import numpy as np
44
import pytest
55

6+
from pandas.core.dtypes.common import is_extension_array_dtype
7+
8+
import pandas as pd
69
from pandas import (
710
DataFrame,
811
DatetimeIndex,
@@ -429,3 +432,29 @@ def test_resample_quantile(series):
429432
result = ser.resample(freq).quantile(q)
430433
expected = ser.resample(freq).agg(lambda x: x.quantile(q)).rename(ser.name)
431434
tm.assert_series_equal(result, expected)
435+
436+
437+
@pytest.mark.parametrize("how", ["first", "last"])
438+
def test_first_last_skipna(any_real_nullable_dtype, skipna, how):
439+
# GH#57019
440+
if is_extension_array_dtype(any_real_nullable_dtype):
441+
na_value = Series(dtype=any_real_nullable_dtype).dtype.na_value
442+
else:
443+
na_value = np.nan
444+
df = DataFrame(
445+
{
446+
"a": [2, 1, 1, 2],
447+
"b": [na_value, 3.0, na_value, 4.0],
448+
"c": [na_value, 3.0, na_value, 4.0],
449+
},
450+
index=date_range("2020-01-01", periods=4, freq="D"),
451+
dtype=any_real_nullable_dtype,
452+
)
453+
rs = df.resample("ME")
454+
method = getattr(rs, how)
455+
result = method(skipna=skipna)
456+
457+
gb = df.groupby(df.shape[0] * [pd.to_datetime("2020-01-31")])
458+
expected = getattr(gb, how)(skipna=skipna)
459+
expected.index.freq = "ME"
460+
tm.assert_frame_equal(result, expected)

pandas/tests/resample/test_resample_api.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1040,11 +1040,11 @@ def test_args_kwargs_depr(method, raises):
10401040
if raises:
10411041
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
10421042
with pytest.raises(UnsupportedFunctionCall, match=error_msg):
1043-
func(*args, 1, 2, 3)
1043+
func(*args, 1, 2, 3, 4)
10441044
else:
10451045
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
10461046
with pytest.raises(TypeError, match=error_msg_type):
1047-
func(*args, 1, 2, 3)
1047+
func(*args, 1, 2, 3, 4)
10481048

10491049

10501050
def test_df_axis_param_depr():

0 commit comments

Comments
 (0)