Skip to content

ENH: Add skipna to groupby.first and groupby.last #57102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/source/whatsnew/v2.2.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Bug fixes

Other
~~~~~
-
- Added the argument ``skipna`` to :meth:`DataFrameGroupBy.first`, :meth:`DataFrameGroupBy.last`, :meth:`SeriesGroupBy.first`, and :meth:`SeriesGroupBy.last`; achieving ``skipna=False`` used to be available via :meth:`DataFrameGroupBy.nth`, but the behavior was changed in pandas 2.0.0 (:issue:`57019`)
- Added the argument ``skipna`` to :meth:`Resampler.first` and :meth:`Resampler.last` (:issue:`57019`)

.. ---------------------------------------------------------------------------
.. _whatsnew_221.contributors:
Expand Down
2 changes: 2 additions & 0 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def group_last(
result_mask: npt.NDArray[np.bool_] | None = ...,
min_count: int = ..., # Py_ssize_t
is_datetimelike: bool = ...,
skipna: bool = ...,
) -> None: ...
def group_nth(
out: np.ndarray, # rank_t[:, ::1]
Expand All @@ -147,6 +148,7 @@ def group_nth(
min_count: int = ..., # int64_t
rank: int = ..., # int64_t
is_datetimelike: bool = ...,
skipna: bool = ...,
) -> None: ...
def group_rank(
out: np.ndarray, # float64_t[:, ::1]
Expand Down
41 changes: 26 additions & 15 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1428,6 +1428,7 @@ def group_last(
uint8_t[:, ::1] result_mask=None,
Py_ssize_t min_count=-1,
bint is_datetimelike=False,
bint skipna=True,
) -> None:
"""
Only aggregates on axis=0
Expand Down Expand Up @@ -1462,14 +1463,19 @@ def group_last(
for j in range(K):
val = values[i, j]

if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if skipna:
if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if isna_entry:
continue

if not isna_entry:
nobs[lab, j] += 1
resx[lab, j] = val
nobs[lab, j] += 1
resx[lab, j] = val

if uses_mask and not skipna:
result_mask[lab, j] = mask[i, j]

_check_below_mincount(
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx
Expand All @@ -1490,6 +1496,7 @@ def group_nth(
int64_t min_count=-1,
int64_t rank=1,
bint is_datetimelike=False,
bint skipna=True,
) -> None:
"""
Only aggregates on axis=0
Expand Down Expand Up @@ -1524,15 +1531,19 @@ def group_nth(
for j in range(K):
val = values[i, j]

if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if skipna:
if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if isna_entry:
continue

if not isna_entry:
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
if uses_mask and not skipna:
result_mask[lab, j] = mask[i, j]

_check_below_mincount(
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx
Expand Down
7 changes: 7 additions & 0 deletions pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,18 @@
+ TIMEDELTA_PYARROW_DTYPES
+ BOOL_PYARROW_DTYPES
)
ALL_REAL_PYARROW_DTYPES_STR_REPR = (
ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR
)
else:
FLOAT_PYARROW_DTYPES_STR_REPR = []
ALL_INT_PYARROW_DTYPES_STR_REPR = []
ALL_PYARROW_DTYPES = []
ALL_REAL_PYARROW_DTYPES_STR_REPR = []

ALL_REAL_NULLABLE_DTYPES = (
FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR
)

arithmetic_dunder_methods = [
"__add__",
Expand Down
32 changes: 32 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,38 @@ def any_numpy_dtype(request):
return request.param


@pytest.fixture(params=tm.ALL_REAL_NULLABLE_DTYPES)
def any_real_nullable_dtype(request):
    """
    Parameterized fixture covering every real (numeric) dtype capable of
    holding missing values.

    The parameter set includes:

    * the numpy floats: float, 'float32', 'float64'
    * the nullable masked floats: 'Float32', 'Float64'
    * the nullable masked unsigned ints: 'UInt8', 'UInt16', 'UInt32', 'UInt64'
    * the nullable masked signed ints: 'Int8', 'Int16', 'Int32', 'Int64'
    * the pyarrow unsigned ints: 'uint8[pyarrow]', 'uint16[pyarrow]',
      'uint32[pyarrow]', 'uint64[pyarrow]'
    * the pyarrow signed ints: 'int8[pyarrow]', 'int16[pyarrow]',
      'int32[pyarrow]', 'int64[pyarrow]'
    * the pyarrow floats: 'float[pyarrow]', 'double[pyarrow]'
    """
    # Each test using this fixture runs once per dtype in the param list.
    return request.param


@pytest.fixture(params=tm.ALL_NUMERIC_DTYPES)
def any_numeric_dtype(request):
"""
Expand Down
36 changes: 28 additions & 8 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3364,22 +3364,31 @@ def max(
)

@final
def first(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
def first(
self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True
) -> NDFrameT:
"""
Compute the first non-null entry of each column.
Compute the first entry of each column within each group.

Defaults to skipping NA elements.

Parameters
----------
numeric_only : bool, default False
Include only float, int, boolean columns.
min_count : int, default -1
The required number of valid values to perform the operation. If fewer
than ``min_count`` non-NA values are present the result will be NA.
than ``min_count`` valid values are present the result will be NA.
skipna : bool, default True
Exclude NA/null values. If an entire row/column is NA, the result
will be NA.

.. versionadded:: 2.2.1

Returns
-------
Series or DataFrame
First non-null of values within each group.
First values within each group.

See Also
--------
Expand Down Expand Up @@ -3431,12 +3440,17 @@ def first(x: Series):
min_count=min_count,
alias="first",
npfunc=first_compat,
skipna=skipna,
)

@final
def last(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
def last(
self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True
) -> NDFrameT:
"""
Compute the last non-null entry of each column.
Compute the last entry of each column within each group.

Defaults to skipping NA elements.

Parameters
----------
Expand All @@ -3445,12 +3459,17 @@ def last(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
everything, then use only numeric data.
min_count : int, default -1
The required number of valid values to perform the operation. If fewer
than ``min_count`` non-NA values are present the result will be NA.
than ``min_count`` valid values are present the result will be NA.
skipna : bool, default True
Exclude NA/null values. If an entire row/column is NA, the result
will be NA.

.. versionadded:: 2.2.1

Returns
-------
Series or DataFrame
Last non-null of values within each group.
Last of values within each group.

See Also
--------
Expand Down Expand Up @@ -3490,6 +3509,7 @@ def last(x: Series):
min_count=min_count,
alias="last",
npfunc=last_compat,
skipna=skipna,
)

@final
Expand Down
10 changes: 8 additions & 2 deletions pandas/core/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,25 +1329,31 @@ def first(
self,
numeric_only: bool = False,
min_count: int = 0,
skipna: bool = True,
*args,
**kwargs,
):
maybe_warn_args_and_kwargs(type(self), "first", args, kwargs)
nv.validate_resampler_func("first", args, kwargs)
return self._downsample("first", numeric_only=numeric_only, min_count=min_count)
return self._downsample(
"first", numeric_only=numeric_only, min_count=min_count, skipna=skipna
)

@final
@doc(GroupBy.last)
def last(
self,
numeric_only: bool = False,
min_count: int = 0,
skipna: bool = True,
*args,
**kwargs,
):
maybe_warn_args_and_kwargs(type(self), "last", args, kwargs)
nv.validate_resampler_func("last", args, kwargs)
return self._downsample("last", numeric_only=numeric_only, min_count=min_count)
return self._downsample(
"last", numeric_only=numeric_only, min_count=min_count, skipna=skipna
)

@final
@doc(GroupBy.median)
Expand Down
31 changes: 31 additions & 0 deletions pandas/tests/groupby/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from pandas._libs.tslibs import iNaT

from pandas.core.dtypes.common import pandas_dtype
from pandas.core.dtypes.missing import na_value_for_dtype

import pandas as pd
from pandas import (
DataFrame,
Expand Down Expand Up @@ -389,6 +392,34 @@ def test_groupby_non_arithmetic_agg_int_like_precision(method, data):
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("how", ["first", "last"])
def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how):
    # GH#57019: groupby first/last should honor skipna
    na_value = na_value_for_dtype(pandas_dtype(any_real_nullable_dtype))
    df = DataFrame(
        {
            "a": [2, 1, 1, 2, 3, 3],
            "b": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
            "c": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
        },
        dtype=any_real_nullable_dtype,
    )
    gb = df.groupby("a", sort=sort)
    result = getattr(gb, how)(skipna=skipna)

    # Row positions of the expected first/last entry per group, depending
    # on whether NA values are skipped.
    if how == "first":
        positions = [3, 1, 4] if skipna else [0, 1, 4]
    else:
        positions = [3, 1, 5] if skipna else [3, 2, 5]
    expected = df.iloc[positions].set_index("a")
    if sort:
        expected = expected.sort_index()
    tm.assert_frame_equal(result, expected)


def test_idxmin_idxmax_axis1():
df = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)), columns=["A", "B", "C", "D"]
Expand Down
29 changes: 29 additions & 0 deletions pandas/tests/resample/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import numpy as np
import pytest

from pandas.core.dtypes.common import is_extension_array_dtype

import pandas as pd
from pandas import (
DataFrame,
DatetimeIndex,
Expand Down Expand Up @@ -459,3 +462,29 @@ def test_resample_quantile(index):
result = ser.resample(freq).quantile(q)
expected = ser.resample(freq).agg(lambda x: x.quantile(q)).rename(ser.name)
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("how", ["first", "last"])
def test_first_last_skipna(any_real_nullable_dtype, skipna, how):
    # GH#57019: resample first/last should honor skipna
    na_value = (
        Series(dtype=any_real_nullable_dtype).dtype.na_value
        if is_extension_array_dtype(any_real_nullable_dtype)
        else np.nan
    )
    column = [na_value, 3.0, na_value, 4.0]
    df = DataFrame(
        {"a": [2, 1, 1, 2], "b": column, "c": column},
        index=date_range("2020-01-01", periods=4, freq="D"),
        dtype=any_real_nullable_dtype,
    )
    result = getattr(df.resample("ME"), how)(skipna=skipna)

    # Resampling monthly puts all four rows in one bin; groupby on a
    # constant month-end key must produce the same aggregation.
    grouped = df.groupby(df.shape[0] * [pd.to_datetime("2020-01-31")])
    expected = getattr(grouped, how)(skipna=skipna)
    expected.index.freq = "ME"
    tm.assert_frame_equal(result, expected)
4 changes: 2 additions & 2 deletions pandas/tests/resample/test_resample_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,11 +1040,11 @@ def test_args_kwargs_depr(method, raises):
if raises:
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
with pytest.raises(UnsupportedFunctionCall, match=error_msg):
func(*args, 1, 2, 3)
func(*args, 1, 2, 3, 4)
else:
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
with pytest.raises(TypeError, match=error_msg_type):
func(*args, 1, 2, 3)
func(*args, 1, 2, 3, 4)


def test_df_axis_param_depr():
Expand Down