Skip to content

ENH: Add skipna to groupby.first and groupby.last #57102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/source/whatsnew/v2.2.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ Bug fixes

Other
~~~~~
-
- Added the argument ``skipna`` to :meth:`DataFrameGroupBy.first`, :meth:`DataFrameGroupBy.last`, :meth:`SeriesGroupBy.first`, and :meth:`SeriesGroupBy.last`; achieving ``skipna=False`` used to be available via :meth:`DataFrameGroupBy.nth`, but the behavior was changed in pandas 2.0.0 (:issue:`57019`)
- Added the argument ``skipna`` to :meth:`Resampler.first`, :meth:`Resampler.last` (:issue:`57019`)

.. ---------------------------------------------------------------------------
.. _whatsnew_221.contributors:
Expand Down
2 changes: 2 additions & 0 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def group_last(
result_mask: npt.NDArray[np.bool_] | None = ...,
min_count: int = ..., # Py_ssize_t
is_datetimelike: bool = ...,
skipna: bool = ...,
) -> None: ...
def group_nth(
out: np.ndarray, # rank_t[:, ::1]
Expand All @@ -147,6 +148,7 @@ def group_nth(
min_count: int = ..., # int64_t
rank: int = ..., # int64_t
is_datetimelike: bool = ...,
skipna: bool = ...,
) -> None: ...
def group_rank(
out: np.ndarray, # float64_t[:, ::1]
Expand Down
41 changes: 26 additions & 15 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1428,6 +1428,7 @@ def group_last(
uint8_t[:, ::1] result_mask=None,
Py_ssize_t min_count=-1,
bint is_datetimelike=False,
bint skipna=True,
) -> None:
"""
Only aggregates on axis=0
Expand Down Expand Up @@ -1462,14 +1463,19 @@ def group_last(
for j in range(K):
val = values[i, j]

if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if skipna:
if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if isna_entry:
continue

if not isna_entry:
nobs[lab, j] += 1
resx[lab, j] = val
nobs[lab, j] += 1
resx[lab, j] = val

if uses_mask and not skipna:
result_mask[lab, j] = mask[i, j]

_check_below_mincount(
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx
Expand All @@ -1490,6 +1496,7 @@ def group_nth(
int64_t min_count=-1,
int64_t rank=1,
bint is_datetimelike=False,
bint skipna=True,
) -> None:
"""
Only aggregates on axis=0
Expand Down Expand Up @@ -1524,15 +1531,19 @@ def group_nth(
for j in range(K):
val = values[i, j]

if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if skipna:
if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if isna_entry:
continue

if not isna_entry:
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
if uses_mask and not skipna:
result_mask[lab, j] = mask[i, j]

_check_below_mincount(
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx
Expand Down
7 changes: 7 additions & 0 deletions pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,18 @@
+ TIMEDELTA_PYARROW_DTYPES
+ BOOL_PYARROW_DTYPES
)
ALL_REAL_PYARROW_DTYPES_STR_REPR = (
ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR
)
else:
FLOAT_PYARROW_DTYPES_STR_REPR = []
ALL_INT_PYARROW_DTYPES_STR_REPR = []
ALL_PYARROW_DTYPES = []
ALL_REAL_PYARROW_DTYPES_STR_REPR = []

ALL_REAL_NULLABLE_DTYPES = (
FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR
)

arithmetic_dunder_methods = [
"__add__",
Expand Down
32 changes: 32 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,38 @@ def any_numpy_dtype(request):
return request.param


@pytest.fixture(params=tm.ALL_REAL_NULLABLE_DTYPES)
def any_real_nullable_dtype(request):
    """
    Parameterized fixture for all real dtypes that can hold NA.

    Covers three families:

    * numpy floats: float, 'float32', 'float64'
    * masked (nullable) extension dtypes: 'Float32', 'Float64',
      'UInt8', 'UInt16', 'UInt32', 'UInt64',
      'Int8', 'Int16', 'Int32', 'Int64'
    * pyarrow-backed dtypes, given as string aliases:
      'uint8[pyarrow]', 'uint16[pyarrow]', 'uint32[pyarrow]',
      'uint64[pyarrow]', 'int8[pyarrow]', 'int16[pyarrow]',
      'int32[pyarrow]', 'int64[pyarrow]',
      'float[pyarrow]', 'double[pyarrow]'
    """
    # The parameter is the dtype itself (or its string alias); tests use it
    # directly when constructing Series/DataFrames.
    return request.param


@pytest.fixture(params=tm.ALL_NUMERIC_DTYPES)
def any_numeric_dtype(request):
"""
Expand Down
36 changes: 28 additions & 8 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3364,22 +3364,31 @@ def max(
)

@final
def first(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
def first(
self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True
) -> NDFrameT:
"""
Compute the first non-null entry of each column.
Compute the first entry of each column within each group.

Defaults to skipping NA elements.

Parameters
----------
numeric_only : bool, default False
Include only float, int, boolean columns.
min_count : int, default -1
The required number of valid values to perform the operation. If fewer
than ``min_count`` non-NA values are present the result will be NA.
than ``min_count`` valid values are present the result will be NA.
skipna : bool, default True
Exclude NA/null values. If an entire row/column is NA, the result
will be NA.

.. versionadded:: 2.2.1

Returns
-------
Series or DataFrame
First non-null of values within each group.
First values within each group.

See Also
--------
Expand Down Expand Up @@ -3431,12 +3440,17 @@ def first(x: Series):
min_count=min_count,
alias="first",
npfunc=first_compat,
skipna=skipna,
)

@final
def last(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
def last(
self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True
) -> NDFrameT:
"""
Compute the last non-null entry of each column.
Compute the last entry of each column within each group.

Defaults to skipping NA elements.

Parameters
----------
Expand All @@ -3445,12 +3459,17 @@ def last(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
everything, then use only numeric data.
min_count : int, default -1
The required number of valid values to perform the operation. If fewer
than ``min_count`` non-NA values are present the result will be NA.
than ``min_count`` valid values are present the result will be NA.
skipna : bool, default True
Exclude NA/null values. If an entire row/column is NA, the result
will be NA.

.. versionadded:: 2.2.1

Returns
-------
Series or DataFrame
Last non-null of values within each group.
Last of values within each group.

See Also
--------
Expand Down Expand Up @@ -3490,6 +3509,7 @@ def last(x: Series):
min_count=min_count,
alias="last",
npfunc=last_compat,
skipna=skipna,
)

@final
Expand Down
10 changes: 8 additions & 2 deletions pandas/core/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,25 +1329,31 @@ def first(
self,
numeric_only: bool = False,
min_count: int = 0,
skipna: bool = True,
*args,
**kwargs,
):
maybe_warn_args_and_kwargs(type(self), "first", args, kwargs)
nv.validate_resampler_func("first", args, kwargs)
return self._downsample("first", numeric_only=numeric_only, min_count=min_count)
return self._downsample(
"first", numeric_only=numeric_only, min_count=min_count, skipna=skipna
)

@final
@doc(GroupBy.last)
def last(
self,
numeric_only: bool = False,
min_count: int = 0,
skipna: bool = True,
*args,
**kwargs,
):
maybe_warn_args_and_kwargs(type(self), "last", args, kwargs)
nv.validate_resampler_func("last", args, kwargs)
return self._downsample("last", numeric_only=numeric_only, min_count=min_count)
return self._downsample(
"last", numeric_only=numeric_only, min_count=min_count, skipna=skipna
)

@final
@doc(GroupBy.median)
Expand Down
33 changes: 33 additions & 0 deletions pandas/tests/groupby/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from pandas._libs.tslibs import iNaT

from pandas.core.dtypes.common import is_extension_array_dtype

import pandas as pd
from pandas import (
DataFrame,
Expand Down Expand Up @@ -389,6 +391,37 @@ def test_groupby_non_arithmetic_agg_int_like_precision(method, data):
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("how", ["first", "last"])
def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how):
# GH#57019
if is_extension_array_dtype(any_real_nullable_dtype):
na_value = Series(dtype=any_real_nullable_dtype).dtype.na_value
else:
na_value = np.nan
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if there is any better way to get the NA value for a dtype when the code needs to span NumPy and EAs (both masked and pyarrow). This is the reason why I went with the string aliases in the any_real_nullable_dtype fixture for pyarrow; this is at odds with some of the other fixtures but the code to the NA value was much worse with pyarrow dtype objects.

cc @mroeschke if you have any suggestions

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use pandas_dtype to get a dtype object from the string, and na_value_for_dtype to get the NA value from the dtype object

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I think we've been using isinstance(..., ExtensionDtype) instead of is_extension_array_dtype if possible

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use pandas_dtype to get a dtype object from the string

Makes sense - but I'd still need to have ALL_REAL_NULLABLE_DTYPES contain the string aliases for pyarrow dtypes, whereas most other lists of dtypes in _testing.__init__ use the pyarrow dtype objects. I just don't want to introduce an inconsistency here (pyarrow dtype objects vs string aliases) if it's avoidable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just don't want to introduce an inconsistency here (pyarrow dtype objects vs string alias) if it's avoidable.

Yeah, I don't think it's avoidable as of now, so I'm okay with the way you have it in this PR.

df = DataFrame(
{
"a": [2, 1, 1, 2, 3, 3],
"b": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
"c": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
},
dtype=any_real_nullable_dtype,
)
gb = df.groupby("a", sort=sort)
method = getattr(gb, how)
result = method(skipna=skipna)

ilocs = {
("first", True): [3, 1, 4],
("first", False): [0, 1, 4],
("last", True): [3, 1, 5],
("last", False): [3, 2, 5],
}[how, skipna]
expected = df.iloc[ilocs].set_index("a")
if sort:
expected = expected.sort_index()
tm.assert_frame_equal(result, expected)


def test_idxmin_idxmax_axis1():
df = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)), columns=["A", "B", "C", "D"]
Expand Down
29 changes: 29 additions & 0 deletions pandas/tests/resample/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import numpy as np
import pytest

from pandas.core.dtypes.common import is_extension_array_dtype

import pandas as pd
from pandas import (
DataFrame,
DatetimeIndex,
Expand Down Expand Up @@ -459,3 +462,29 @@ def test_resample_quantile(index):
result = ser.resample(freq).quantile(q)
expected = ser.resample(freq).agg(lambda x: x.quantile(q)).rename(ser.name)
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("how", ["first", "last"])
def test_first_last_skipna(any_real_nullable_dtype, skipna, how):
    # GH#57019
    # Resample's first/last should honor skipna and agree with groupby.
    # Derive the NA value from the dtype itself: na_value_for_dtype covers
    # numpy, masked, and pyarrow dtypes uniformly, so no manual
    # is_extension_array_dtype branch is needed.
    from pandas.core.dtypes.missing import na_value_for_dtype

    na_value = na_value_for_dtype(pd.api.types.pandas_dtype(any_real_nullable_dtype))
    df = DataFrame(
        {
            "a": [2, 1, 1, 2],
            "b": [na_value, 3.0, na_value, 4.0],
            "c": [na_value, 3.0, na_value, 4.0],
        },
        index=date_range("2020-01-01", periods=4, freq="D"),
        dtype=any_real_nullable_dtype,
    )
    rs = df.resample("ME")
    method = getattr(rs, how)
    result = method(skipna=skipna)

    # All four rows fall into the single month-end bin, so resampling must
    # match a groupby over a constant key with the same skipna setting.
    gb = df.groupby(df.shape[0] * [pd.to_datetime("2020-01-31")])
    expected = getattr(gb, how)(skipna=skipna)
    expected.index.freq = "ME"
    tm.assert_frame_equal(result, expected)