Skip to content

Commit ab3d4bf

Browse files
authored
ENH: Add skipna to groupby.first and groupby.last (#57102)
* ENH: Add skipna to groupby.first and groupby.last * resample & tests * Improve test * Fixups * fixup test * Rework na_value determination
1 parent dc5586b commit ab3d4bf

File tree

10 files changed

+167
-28
lines changed

10 files changed

+167
-28
lines changed

doc/source/whatsnew/v2.2.1.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ Bug fixes
3434

3535
Other
3636
~~~~~
37-
-
37+
- Added the argument ``skipna`` to :meth:`DataFrameGroupBy.first`, :meth:`DataFrameGroupBy.last`, :meth:`SeriesGroupBy.first`, and :meth:`SeriesGroupBy.last`; achieving ``skipna=False`` used to be available via :meth:`DataFrameGroupBy.nth`, but the behavior was changed in pandas 2.0.0 (:issue:`57019`)
38+
- Added the argument ``skipna`` to :meth:`Resampler.first`, :meth:`Resampler.last` (:issue:`57019`)
3839

3940
.. ---------------------------------------------------------------------------
4041
.. _whatsnew_221.contributors:

pandas/_libs/groupby.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def group_last(
136136
result_mask: npt.NDArray[np.bool_] | None = ...,
137137
min_count: int = ..., # Py_ssize_t
138138
is_datetimelike: bool = ...,
139+
skipna: bool = ...,
139140
) -> None: ...
140141
def group_nth(
141142
out: np.ndarray, # rank_t[:, ::1]
@@ -147,6 +148,7 @@ def group_nth(
147148
min_count: int = ..., # int64_t
148149
rank: int = ..., # int64_t
149150
is_datetimelike: bool = ...,
151+
skipna: bool = ...,
150152
) -> None: ...
151153
def group_rank(
152154
out: np.ndarray, # float64_t[:, ::1]

pandas/_libs/groupby.pyx

+26-15
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,7 @@ def group_last(
14281428
uint8_t[:, ::1] result_mask=None,
14291429
Py_ssize_t min_count=-1,
14301430
bint is_datetimelike=False,
1431+
bint skipna=True,
14311432
) -> None:
14321433
"""
14331434
Only aggregates on axis=0
@@ -1462,14 +1463,19 @@ def group_last(
14621463
for j in range(K):
14631464
val = values[i, j]
14641465

1465-
if uses_mask:
1466-
isna_entry = mask[i, j]
1467-
else:
1468-
isna_entry = _treat_as_na(val, is_datetimelike)
1466+
if skipna:
1467+
if uses_mask:
1468+
isna_entry = mask[i, j]
1469+
else:
1470+
isna_entry = _treat_as_na(val, is_datetimelike)
1471+
if isna_entry:
1472+
continue
14691473

1470-
if not isna_entry:
1471-
nobs[lab, j] += 1
1472-
resx[lab, j] = val
1474+
nobs[lab, j] += 1
1475+
resx[lab, j] = val
1476+
1477+
if uses_mask and not skipna:
1478+
result_mask[lab, j] = mask[i, j]
14731479

14741480
_check_below_mincount(
14751481
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx
@@ -1490,6 +1496,7 @@ def group_nth(
14901496
int64_t min_count=-1,
14911497
int64_t rank=1,
14921498
bint is_datetimelike=False,
1499+
bint skipna=True,
14931500
) -> None:
14941501
"""
14951502
Only aggregates on axis=0
@@ -1524,15 +1531,19 @@ def group_nth(
15241531
for j in range(K):
15251532
val = values[i, j]
15261533

1527-
if uses_mask:
1528-
isna_entry = mask[i, j]
1529-
else:
1530-
isna_entry = _treat_as_na(val, is_datetimelike)
1534+
if skipna:
1535+
if uses_mask:
1536+
isna_entry = mask[i, j]
1537+
else:
1538+
isna_entry = _treat_as_na(val, is_datetimelike)
1539+
if isna_entry:
1540+
continue
15311541

1532-
if not isna_entry:
1533-
nobs[lab, j] += 1
1534-
if nobs[lab, j] == rank:
1535-
resx[lab, j] = val
1542+
nobs[lab, j] += 1
1543+
if nobs[lab, j] == rank:
1544+
resx[lab, j] = val
1545+
if uses_mask and not skipna:
1546+
result_mask[lab, j] = mask[i, j]
15361547

15371548
_check_below_mincount(
15381549
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx

pandas/_testing/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,18 @@
235235
+ TIMEDELTA_PYARROW_DTYPES
236236
+ BOOL_PYARROW_DTYPES
237237
)
238+
ALL_REAL_PYARROW_DTYPES_STR_REPR = (
239+
ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR
240+
)
238241
else:
239242
FLOAT_PYARROW_DTYPES_STR_REPR = []
240243
ALL_INT_PYARROW_DTYPES_STR_REPR = []
241244
ALL_PYARROW_DTYPES = []
245+
ALL_REAL_PYARROW_DTYPES_STR_REPR = []
242246

247+
ALL_REAL_NULLABLE_DTYPES = (
248+
FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR
249+
)
243250

244251
arithmetic_dunder_methods = [
245252
"__add__",

pandas/conftest.py

+32
Original file line numberDiff line numberDiff line change
@@ -1703,6 +1703,38 @@ def any_numpy_dtype(request):
17031703
return request.param
17041704

17051705

1706+
@pytest.fixture(params=tm.ALL_REAL_NULLABLE_DTYPES)
1707+
def any_real_nullable_dtype(request):
1708+
"""
1709+
Parameterized fixture for all real dtypes that can hold NA.
1710+
1711+
* float
1712+
* 'float32'
1713+
* 'float64'
1714+
* 'Float32'
1715+
* 'Float64'
1716+
* 'UInt8'
1717+
* 'UInt16'
1718+
* 'UInt32'
1719+
* 'UInt64'
1720+
* 'Int8'
1721+
* 'Int16'
1722+
* 'Int32'
1723+
* 'Int64'
1724+
* 'uint8[pyarrow]'
1725+
* 'uint16[pyarrow]'
1726+
* 'uint32[pyarrow]'
1727+
* 'uint64[pyarrow]'
1728+
* 'int8[pyarrow]'
1729+
* 'int16[pyarrow]'
1730+
* 'int32[pyarrow]'
1731+
* 'int64[pyarrow]'
1732+
* 'float[pyarrow]'
1733+
* 'double[pyarrow]'
1734+
"""
1735+
return request.param
1736+
1737+
17061738
@pytest.fixture(params=tm.ALL_NUMERIC_DTYPES)
17071739
def any_numeric_dtype(request):
17081740
"""

pandas/core/groupby/groupby.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -3364,22 +3364,31 @@ def max(
33643364
)
33653365

33663366
@final
3367-
def first(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
3367+
def first(
3368+
self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True
3369+
) -> NDFrameT:
33683370
"""
3369-
Compute the first non-null entry of each column.
3371+
Compute the first entry of each column within each group.
3372+
3373+
Defaults to skipping NA elements.
33703374
33713375
Parameters
33723376
----------
33733377
numeric_only : bool, default False
33743378
Include only float, int, boolean columns.
33753379
min_count : int, default -1
33763380
The required number of valid values to perform the operation. If fewer
3377-
than ``min_count`` non-NA values are present the result will be NA.
3381+
than ``min_count`` valid values are present the result will be NA.
3382+
skipna : bool, default True
3383+
Exclude NA/null values. If an entire row/column is NA, the result
3384+
will be NA.
3385+
3386+
.. versionadded:: 2.2.1
33783387
33793388
Returns
33803389
-------
33813390
Series or DataFrame
3382-
First non-null of values within each group.
3391+
First values within each group.
33833392
33843393
See Also
33853394
--------
@@ -3431,12 +3440,17 @@ def first(x: Series):
34313440
min_count=min_count,
34323441
alias="first",
34333442
npfunc=first_compat,
3443+
skipna=skipna,
34343444
)
34353445

34363446
@final
3437-
def last(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
3447+
def last(
3448+
self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True
3449+
) -> NDFrameT:
34383450
"""
3439-
Compute the last non-null entry of each column.
3451+
Compute the last entry of each column within each group.
3452+
3453+
Defaults to skipping NA elements.
34403454
34413455
Parameters
34423456
----------
@@ -3445,12 +3459,17 @@ def last(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
34453459
everything, then use only numeric data.
34463460
min_count : int, default -1
34473461
The required number of valid values to perform the operation. If fewer
3448-
than ``min_count`` non-NA values are present the result will be NA.
3462+
than ``min_count`` valid values are present the result will be NA.
3463+
skipna : bool, default True
3464+
Exclude NA/null values. If an entire row/column is NA, the result
3465+
will be NA.
3466+
3467+
.. versionadded:: 2.2.1
34493468
34503469
Returns
34513470
-------
34523471
Series or DataFrame
3453-
Last non-null of values within each group.
3472+
Last of values within each group.
34543473
34553474
See Also
34563475
--------
@@ -3490,6 +3509,7 @@ def last(x: Series):
34903509
min_count=min_count,
34913510
alias="last",
34923511
npfunc=last_compat,
3512+
skipna=skipna,
34933513
)
34943514

34953515
@final

pandas/core/resample.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1329,25 +1329,31 @@ def first(
13291329
self,
13301330
numeric_only: bool = False,
13311331
min_count: int = 0,
1332+
skipna: bool = True,
13321333
*args,
13331334
**kwargs,
13341335
):
13351336
maybe_warn_args_and_kwargs(type(self), "first", args, kwargs)
13361337
nv.validate_resampler_func("first", args, kwargs)
1337-
return self._downsample("first", numeric_only=numeric_only, min_count=min_count)
1338+
return self._downsample(
1339+
"first", numeric_only=numeric_only, min_count=min_count, skipna=skipna
1340+
)
13381341

13391342
@final
13401343
@doc(GroupBy.last)
13411344
def last(
13421345
self,
13431346
numeric_only: bool = False,
13441347
min_count: int = 0,
1348+
skipna: bool = True,
13451349
*args,
13461350
**kwargs,
13471351
):
13481352
maybe_warn_args_and_kwargs(type(self), "last", args, kwargs)
13491353
nv.validate_resampler_func("last", args, kwargs)
1350-
return self._downsample("last", numeric_only=numeric_only, min_count=min_count)
1354+
return self._downsample(
1355+
"last", numeric_only=numeric_only, min_count=min_count, skipna=skipna
1356+
)
13511357

13521358
@final
13531359
@doc(GroupBy.median)

pandas/tests/groupby/test_reductions.py

+31
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
from pandas._libs.tslibs import iNaT
99

10+
from pandas.core.dtypes.common import pandas_dtype
11+
from pandas.core.dtypes.missing import na_value_for_dtype
12+
1013
import pandas as pd
1114
from pandas import (
1215
DataFrame,
@@ -389,6 +392,34 @@ def test_groupby_non_arithmetic_agg_int_like_precision(method, data):
389392
tm.assert_frame_equal(result, expected)
390393

391394

395+
@pytest.mark.parametrize("how", ["first", "last"])
396+
def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how):
397+
# GH#57019
398+
na_value = na_value_for_dtype(pandas_dtype(any_real_nullable_dtype))
399+
df = DataFrame(
400+
{
401+
"a": [2, 1, 1, 2, 3, 3],
402+
"b": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
403+
"c": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
404+
},
405+
dtype=any_real_nullable_dtype,
406+
)
407+
gb = df.groupby("a", sort=sort)
408+
method = getattr(gb, how)
409+
result = method(skipna=skipna)
410+
411+
ilocs = {
412+
("first", True): [3, 1, 4],
413+
("first", False): [0, 1, 4],
414+
("last", True): [3, 1, 5],
415+
("last", False): [3, 2, 5],
416+
}[how, skipna]
417+
expected = df.iloc[ilocs].set_index("a")
418+
if sort:
419+
expected = expected.sort_index()
420+
tm.assert_frame_equal(result, expected)
421+
422+
392423
def test_idxmin_idxmax_axis1():
393424
df = DataFrame(
394425
np.random.default_rng(2).standard_normal((10, 4)), columns=["A", "B", "C", "D"]

pandas/tests/resample/test_base.py

+29
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import numpy as np
44
import pytest
55

6+
from pandas.core.dtypes.common import is_extension_array_dtype
7+
8+
import pandas as pd
69
from pandas import (
710
DataFrame,
811
DatetimeIndex,
@@ -459,3 +462,29 @@ def test_resample_quantile(index):
459462
result = ser.resample(freq).quantile(q)
460463
expected = ser.resample(freq).agg(lambda x: x.quantile(q)).rename(ser.name)
461464
tm.assert_series_equal(result, expected)
465+
466+
467+
@pytest.mark.parametrize("how", ["first", "last"])
468+
def test_first_last_skipna(any_real_nullable_dtype, skipna, how):
469+
# GH#57019
470+
if is_extension_array_dtype(any_real_nullable_dtype):
471+
na_value = Series(dtype=any_real_nullable_dtype).dtype.na_value
472+
else:
473+
na_value = np.nan
474+
df = DataFrame(
475+
{
476+
"a": [2, 1, 1, 2],
477+
"b": [na_value, 3.0, na_value, 4.0],
478+
"c": [na_value, 3.0, na_value, 4.0],
479+
},
480+
index=date_range("2020-01-01", periods=4, freq="D"),
481+
dtype=any_real_nullable_dtype,
482+
)
483+
rs = df.resample("ME")
484+
method = getattr(rs, how)
485+
result = method(skipna=skipna)
486+
487+
gb = df.groupby(df.shape[0] * [pd.to_datetime("2020-01-31")])
488+
expected = getattr(gb, how)(skipna=skipna)
489+
expected.index.freq = "ME"
490+
tm.assert_frame_equal(result, expected)

pandas/tests/resample/test_resample_api.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1040,11 +1040,11 @@ def test_args_kwargs_depr(method, raises):
10401040
if raises:
10411041
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
10421042
with pytest.raises(UnsupportedFunctionCall, match=error_msg):
1043-
func(*args, 1, 2, 3)
1043+
func(*args, 1, 2, 3, 4)
10441044
else:
10451045
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
10461046
with pytest.raises(TypeError, match=error_msg_type):
1047-
func(*args, 1, 2, 3)
1047+
func(*args, 1, 2, 3, 4)
10481048

10491049

10501050
def test_df_axis_param_depr():

0 commit comments

Comments
 (0)