diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 7a7050ea8bad7..48822d9d01ddb 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -550,10 +550,12 @@ def apply_str(self) -> DataFrame | Series: func = getattr(obj, f, None) if callable(func): sig = inspect.getfullargspec(func) - if "axis" in sig.args: - self.kwargs["axis"] = self.axis - elif self.axis != 0: + if self.axis != 0 and ( + "axis" not in sig.args or f in ("corrwith", "mad", "skew") + ): raise ValueError(f"Operation {f} does not support axis=1") + elif "axis" in sig.args: + self.kwargs["axis"] = self.axis return self._try_aggregate_string_function(obj, f, *self.args, **self.kwargs) def apply_multiple(self) -> DataFrame | Series: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 32d4ac24a1d53..7a201c1d9202f 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -11657,10 +11657,8 @@ def all(self, axis=0, bool_only=None, skipna=True, level=None, **kwargs): setattr(cls, "all", all) - # error: Argument 1 to "doc" has incompatible type "Optional[str]"; expected - # "Union[str, Callable[..., Any]]" @doc( - NDFrame.mad.__doc__, # type: ignore[arg-type] + NDFrame.mad.__doc__, desc="Return the mean absolute deviation of the values " "over the requested axis.", name1=name1, diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py index ad1f36e0cddd8..953fc4673a38e 100644 --- a/pandas/core/groupby/base.py +++ b/pandas/core/groupby/base.py @@ -1,7 +1,5 @@ """ -Provide basic components for groupby. These definitions -hold the allowlist of methods that are exposed on the -SeriesGroupBy and the DataFrameGroupBy objects. +Provide basic components for groupby. """ from __future__ import annotations @@ -22,36 +20,6 @@ class OutputKey: # forwarding methods from NDFrames plotting_methods = frozenset(["plot", "hist"]) -common_apply_allowlist = ( - frozenset( - [ - "quantile", - "fillna", - "mad", - "take", - "idxmax", - "idxmin", - "tshift", - "skew", - "corr", - "cov", - "diff", - ] - ) - | plotting_methods -) - -series_apply_allowlist: frozenset[str] = ( - common_apply_allowlist - | frozenset( - {"nlargest", "nsmallest", "is_monotonic_increasing", "is_monotonic_decreasing"} - ) -) | frozenset(["dtype", "unique"]) - -dataframe_apply_allowlist: frozenset[str] = common_apply_allowlist | frozenset( - ["dtypes", "corrwith"] -) - # cythonized transformations or canned "agg+broadcast", which do not # require postprocessing of the result by transform. cythonized_kernels = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"]) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 7fe1d55ba55be..c06042915cbc2 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -17,6 +17,7 @@ Callable, Hashable, Iterable, + Literal, Mapping, NamedTuple, Sequence, @@ -35,9 +36,14 @@ ) from pandas._typing import ( ArrayLike, + Axis, + FillnaOptions, + IndexLabel, + Level, Manager, Manager2D, SingleManager, + TakeIndexer, ) from pandas.errors import SpecificationError from pandas.util._decorators import ( @@ -78,6 +84,7 @@ from pandas.core.groupby import base from pandas.core.groupby.groupby import ( GroupBy, + GroupByPlot, _agg_template, _apply_docs, _transform_template, @@ -135,48 +142,7 @@ def prop(self): return property(prop) -def pin_allowlisted_properties( - klass: type[DataFrame | Series], allowlist: frozenset[str] -): - """ - Create GroupBy member defs for DataFrame/Series names in a allowlist. - - Parameters - ---------- - klass : DataFrame or Series class - class where members are defined. - allowlist : frozenset[str] - Set of names of klass methods to be constructed - - Returns - ------- - class decorator - - Notes - ----- - Since we don't want to override methods explicitly defined in the - base class, any such name is skipped. - """ - - def pinner(cls): - for name in allowlist: - if hasattr(cls, name): - # don't override anything that was explicitly defined - # in the base class - continue - - prop = generate_property(name, klass) - setattr(cls, name, prop) - - return cls - - return pinner - - -@pin_allowlisted_properties(Series, base.series_apply_allowlist) class SeriesGroupBy(GroupBy[Series]): - _apply_allowlist = base.series_apply_allowlist - def _wrap_agged_manager(self, mgr: Manager) -> Series: if mgr.ndim == 1: mgr = cast(SingleManager, mgr) @@ -754,8 +720,82 @@ def build_codes(lev_codes: np.ndarray) -> np.ndarray: out = ensure_int64(out) return self.obj._constructor(out, index=mi, name=self.obj.name) - @doc(Series.nlargest) - def nlargest(self, n: int = 5, keep: str = "first") -> Series: + @doc(Series.fillna.__doc__) + def fillna( + self, + value: object | ArrayLike | None = None, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, + ) -> Series | None: + result = self._op_via_apply( + "fillna", + value=value, + method=method, + axis=axis, + inplace=inplace, + limit=limit, + downcast=downcast, + ) + return result + + @doc(Series.take.__doc__) + def take( + self, + indices: TakeIndexer, + axis: Axis = 0, + is_copy: bool | None = None, + **kwargs, + ) -> Series: + result = self._op_via_apply( + "take", indices=indices, axis=axis, is_copy=is_copy, **kwargs + ) + return result + + @doc(Series.skew.__doc__) + def skew( + self, + axis: Axis | lib.NoDefault = lib.no_default, + skipna: bool = True, + level: Level | None = None, + numeric_only: bool | None = None, + **kwargs, + ) -> Series: + result = self._op_via_apply( + "skew", + axis=axis, + skipna=skipna, + level=level, + numeric_only=numeric_only, + **kwargs, + ) + return result + + @doc(Series.mad.__doc__) + def mad( + self, axis: Axis | None = None, skipna: bool = True, level: Level | None = None + ) -> Series: + result = self._op_via_apply("mad", axis=axis, skipna=skipna, level=level) + return result + + @doc(Series.tshift.__doc__) + def tshift(self, periods: int = 1, freq=None) -> Series: + result = self._op_via_apply("tshift", periods=periods, freq=freq) + return result + + # Decorated property not supported - https://github.com/python/mypy/issues/1362 + @property # type: ignore[misc] + @doc(Series.plot.__doc__) + def plot(self): + result = GroupByPlot(self) + return result + + @doc(Series.nlargest.__doc__) + def nlargest( + self, n: int = 5, keep: Literal["first", "last", "all"] = "first" + ) -> Series: f = partial(Series.nlargest, n=n, keep=keep) data = self._obj_with_exclusions # Don't change behavior if result index happens to be the same, i.e. @@ -763,8 +803,10 @@ def nlargest(self, n: int = 5, keep: str = "first") -> Series: result = self._python_apply_general(f, data, not_indexed_same=True) return result - @doc(Series.nsmallest) - def nsmallest(self, n: int = 5, keep: str = "first") -> Series: + @doc(Series.nsmallest.__doc__) + def nsmallest( + self, n: int = 5, keep: Literal["first", "last", "all"] = "first" + ) -> Series: f = partial(Series.nsmallest, n=n, keep=keep) data = self._obj_with_exclusions # Don't change behavior if result index happens to be the same, i.e. @@ -772,11 +814,99 @@ def nsmallest(self, n: int = 5, keep: str = "first") -> Series: result = self._python_apply_general(f, data, not_indexed_same=True) return result + @doc(Series.idxmin.__doc__) + def idxmin(self, axis: Axis = 0, skipna: bool = True) -> Series: + result = self._op_via_apply("idxmin", axis=axis, skipna=skipna) + return result -@pin_allowlisted_properties(DataFrame, base.dataframe_apply_allowlist) -class DataFrameGroupBy(GroupBy[DataFrame]): + @doc(Series.idxmax.__doc__) + def idxmax(self, axis: Axis = 0, skipna: bool = True) -> Series: + result = self._op_via_apply("idxmax", axis=axis, skipna=skipna) + return result + + @doc(Series.corr.__doc__) + def corr( + self, + other: Series, + method: Literal["pearson", "kendall", "spearman"] + | Callable[[np.ndarray, np.ndarray], float] = "pearson", + min_periods: int | None = None, + ) -> Series: + result = self._op_via_apply( + "corr", other=other, method=method, min_periods=min_periods + ) + return result + + @doc(Series.cov.__doc__) + def cov( + self, other: Series, min_periods: int | None = None, ddof: int | None = 1 + ) -> Series: + result = self._op_via_apply( + "cov", other=other, min_periods=min_periods, ddof=ddof + ) + return result + + # Decorated property not supported - https://github.com/python/mypy/issues/1362 + @property # type: ignore[misc] + @doc(Series.is_monotonic_increasing.__doc__) + def is_monotonic_increasing(self) -> Series: + result = self._op_via_apply("is_monotonic_increasing") + return result + + # Decorated property not supported - https://github.com/python/mypy/issues/1362 + @property # type: ignore[misc] + @doc(Series.is_monotonic_decreasing.__doc__) + def is_monotonic_decreasing(self) -> Series: + result = self._op_via_apply("is_monotonic_decreasing") + return result + + @doc(Series.hist.__doc__) + def hist( + self, + by=None, + ax=None, + grid: bool = True, + xlabelsize: int | None = None, + xrot: float | None = None, + ylabelsize: int | None = None, + yrot: float | None = None, + figsize: tuple[int, int] | None = None, + bins: int | Sequence[int] = 10, + backend: str | None = None, + legend: bool = False, + **kwargs, + ): + result = self._op_via_apply( + "hist", + by=by, + ax=ax, + grid=grid, + xlabelsize=xlabelsize, + xrot=xrot, + ylabelsize=ylabelsize, + yrot=yrot, + figsize=figsize, + bins=bins, + backend=backend, + legend=legend, + **kwargs, + ) + return result + + # Decorated property not supported - https://github.com/python/mypy/issues/1362 + @property # type: ignore[misc] + @doc(Series.dtype.__doc__) + def dtype(self) -> Series: + result = self._op_via_apply("dtype") + return result + + @doc(Series.unique.__doc__) + def unique(self) -> Series: + result = self._op_via_apply("unique") + return result - _apply_allowlist = base.dataframe_apply_allowlist + +class DataFrameGroupBy(GroupBy[DataFrame]): _agg_examples_doc = dedent( """ @@ -1911,6 +2041,169 @@ def value_counts( result = result_frame return result.__finalize__(self.obj, method="value_counts") + @doc(DataFrame.fillna.__doc__) + def fillna( + self, + value: Hashable | Mapping | Series | DataFrame = None, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit=None, + downcast=None, + ) -> DataFrame | None: + result = self._op_via_apply( + "fillna", + value=value, + method=method, + axis=axis, + inplace=inplace, + limit=limit, + downcast=downcast, + ) + return result + + @doc(DataFrame.take.__doc__) + def take( + self, + indices: TakeIndexer, + axis: Axis | None = 0, + is_copy: bool | None = None, + **kwargs, + ) -> DataFrame: + result = self._op_via_apply( + "take", indices=indices, axis=axis, is_copy=is_copy, **kwargs + ) + return result + + @doc(DataFrame.skew.__doc__) + def skew( + self, + axis: Axis | None | lib.NoDefault = lib.no_default, + skipna: bool = True, + level: Level | None = None, + numeric_only: bool | lib.NoDefault = lib.no_default, + **kwargs, + ) -> DataFrame: + result = self._op_via_apply( + "skew", + axis=axis, + skipna=skipna, + level=level, + numeric_only=numeric_only, + **kwargs, + ) + return result + + @doc(DataFrame.mad.__doc__) + def mad( + self, axis: Axis | None = None, skipna: bool = True, level: Level | None = None + ) -> DataFrame: + result = self._op_via_apply("mad", axis=axis, skipna=skipna, level=level) + return result + + @doc(DataFrame.tshift.__doc__) + def tshift(self, periods: int = 1, freq=None, axis: Axis = 0) -> DataFrame: + result = self._op_via_apply("tshift", periods=periods, freq=freq, axis=axis) + return result + + @property # type: ignore[misc] + @doc(DataFrame.plot.__doc__) + def plot(self) -> GroupByPlot: + result = GroupByPlot(self) + return result + + @doc(DataFrame.corr.__doc__) + def corr( + self, + method: str | Callable[[np.ndarray, np.ndarray], float] = "pearson", + min_periods: int = 1, + numeric_only: bool | lib.NoDefault = lib.no_default, + ) -> DataFrame: + result = self._op_via_apply( + "corr", method=method, min_periods=min_periods, numeric_only=numeric_only + ) + return result + + @doc(DataFrame.cov.__doc__) + def cov( + self, + min_periods: int | None = None, + ddof: int | None = 1, + numeric_only: bool | lib.NoDefault = lib.no_default, + ) -> DataFrame: + result = self._op_via_apply( + "cov", min_periods=min_periods, ddof=ddof, numeric_only=numeric_only + ) + return result + + @doc(DataFrame.hist.__doc__) + def hist( + self, + column: IndexLabel = None, + by=None, + grid: bool = True, + xlabelsize: int | None = None, + xrot: float | None = None, + ylabelsize: int | None = None, + yrot: float | None = None, + ax=None, + sharex: bool = False, + sharey: bool = False, + figsize: tuple[int, int] | None = None, + layout: tuple[int, int] | None = None, + bins: int | Sequence[int] = 10, + backend: str | None = None, + legend: bool = False, + **kwargs, + ): + result = self._op_via_apply( + "hist", + column=column, + by=by, + grid=grid, + xlabelsize=xlabelsize, + xrot=xrot, + ylabelsize=ylabelsize, + yrot=yrot, + ax=ax, + sharex=sharex, + sharey=sharey, + figsize=figsize, + layout=layout, + bins=bins, + backend=backend, + legend=legend, + **kwargs, + ) + return result + + # Decorated property not supported - https://github.com/python/mypy/issues/1362 + @property # type: ignore[misc] + @doc(DataFrame.dtypes.__doc__) + def dtypes(self) -> Series: + result = self._op_via_apply("dtypes") + return result + + @doc(DataFrame.corrwith.__doc__) + def corrwith( + self, + other: DataFrame | Series, + axis: Axis = 0, + drop: bool = False, + method: Literal["pearson", "kendall", "spearman"] + | Callable[[np.ndarray, np.ndarray], float] = "pearson", + numeric_only: bool | lib.NoDefault = lib.no_default, + ) -> DataFrame: + result = self._op_via_apply( + "corrwith", + other=other, + axis=axis, + drop=drop, + method=method, + numeric_only=numeric_only, + ) + return result + def _wrap_transform_general_frame( obj: DataFrame, group: DataFrame, res: DataFrame | Series diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 89c9f3701a424..a22774f8a2232 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -16,7 +16,6 @@ class providing the base-class of operations. ) import inspect from textwrap import dedent -import types from typing import ( TYPE_CHECKING, Callable, @@ -626,7 +625,6 @@ def f(self): class BaseGroupBy(PandasObject, SelectionMixin[NDFrameT], GroupByIndexingMixin): _group_selection: IndexLabel | None = None - _apply_allowlist: frozenset[str] = frozenset() _hidden_attrs = PandasObject._hidden_attrs | { "as_index", "axis", @@ -750,7 +748,7 @@ def _selected_obj(self): @final def _dir_additions(self) -> set[str]: - return self.obj._dir_additions() | self._apply_allowlist + return self.obj._dir_additions() @Substitution( klass="GroupBy", @@ -783,8 +781,6 @@ def pipe( ) -> T: return com.pipe(self, func, *args, **kwargs) - plot = property(GroupByPlot) - @final def get_group(self, name, obj=None) -> DataFrame | Series: """ @@ -992,75 +988,66 @@ def __getattribute__(self, attr: str): return super().__getattribute__(attr) @final - def _make_wrapper(self, name: str) -> Callable: - assert name in self._apply_allowlist - + def _op_via_apply(self, name: str, *args, **kwargs): + """Compute the result of an operation by using GroupBy's apply.""" + f = getattr(type(self._obj_with_exclusions), name) with self._group_selection_context(): # need to setup the selection # as are not passed directly but in the grouper - f = getattr(self._obj_with_exclusions, name) - if not isinstance(f, types.MethodType): - # error: Incompatible return value type - # (got "NDFrameT", expected "Callable[..., Any]") [return-value] - return cast(Callable, self.apply(lambda self: getattr(self, name))) + f = getattr(type(self._obj_with_exclusions), name) + if not callable(f): + return self.apply(lambda self: getattr(self, name)) - f = getattr(type(self._obj_with_exclusions), name) sig = inspect.signature(f) - def wrapper(*args, **kwargs): - # a little trickery for aggregation functions that need an axis - # argument - if "axis" in sig.parameters: - if kwargs.get("axis", None) is None: - kwargs["axis"] = self.axis - - numeric_only = kwargs.get("numeric_only", lib.no_default) + # a little trickery for aggregation functions that need an axis + # argument + if "axis" in sig.parameters: + if kwargs.get("axis", None) is None or kwargs.get("axis") is lib.no_default: + kwargs["axis"] = self.axis - def curried(x): - with warnings.catch_warnings(): - # Catch any warnings from dispatch to DataFrame; we'll emit - # a warning for groupby below - match = "The default value of numeric_only " - warnings.filterwarnings("ignore", match, FutureWarning) - return f(x, *args, **kwargs) + numeric_only = kwargs.get("numeric_only", lib.no_default) - # preserve the name so we can detect it when calling plot methods, - # to avoid duplicates - curried.__name__ = name - - # special case otherwise extra plots are created when catching the - # exception below - if name in base.plotting_methods: - return self.apply(curried) - - is_transform = name in base.transformation_kernels - - # Transform needs to keep the same schema, including when empty - if is_transform and self._obj_with_exclusions.empty: - return self._obj_with_exclusions - - result = self._python_apply_general( - curried, - self._obj_with_exclusions, - is_transform=is_transform, - not_indexed_same=not is_transform, - ) - - if self._selected_obj.ndim != 1 and self.axis != 1 and result.ndim != 1: - missing = self._obj_with_exclusions.columns.difference(result.columns) - if len(missing) > 0: - warn_dropping_nuisance_columns_deprecated( - type(self), name, numeric_only - ) + def curried(x): + with warnings.catch_warnings(): + # Catch any warnings from dispatch to DataFrame; we'll emit + # a warning for groupby below + match = "The default value of numeric_only " + warnings.filterwarnings("ignore", match, FutureWarning) + return f(x, *args, **kwargs) + + # preserve the name so we can detect it when calling plot methods, + # to avoid duplicates + curried.__name__ = name + + # special case otherwise extra plots are created when catching the + # exception below + if name in base.plotting_methods: + return self.apply(curried) + + is_transform = name in base.transformation_kernels + # Transform needs to keep the same schema, including when empty + if is_transform and self._obj_with_exclusions.empty: + return self._obj_with_exclusions + result = self._python_apply_general( + curried, + self._obj_with_exclusions, + is_transform=is_transform, + not_indexed_same=not is_transform, + ) - if self.grouper.has_dropped_na and is_transform: - # result will have dropped rows due to nans, fill with null - # and ensure index is ordered same as the input - result = self._set_result_index_ordered(result) - return result + if self._selected_obj.ndim != 1 and self.axis != 1 and result.ndim != 1: + missing = self._obj_with_exclusions.columns.difference(result.columns) + if len(missing) > 0: + warn_dropping_nuisance_columns_deprecated( + type(self), name, numeric_only + ) - wrapper.__name__ = name - return wrapper + if self.grouper.has_dropped_na and is_transform: + # result will have dropped rows due to nans, fill with null + # and ensure index is ordered same as the input + result = self._set_result_index_ordered(result) + return result # ----------------------------------------------------------------- # Selection diff --git a/pandas/tests/groupby/test_allowlist.py b/pandas/tests/groupby/test_allowlist.py index e541abb368a02..b9a7bb271e948 100644 --- a/pandas/tests/groupby/test_allowlist.py +++ b/pandas/tests/groupby/test_allowlist.py @@ -35,57 +35,6 @@ ] AGG_FUNCTIONS_WITH_SKIPNA = ["skew", "mad"] -df_allowlist = [ - "quantile", - "fillna", - "mad", - "take", - "idxmax", - "idxmin", - "tshift", - "skew", - "plot", - "hist", - "dtypes", - "corrwith", - "corr", - "cov", - "diff", -] - - -@pytest.fixture(params=df_allowlist) -def df_allowlist_fixture(request): - return request.param - - -s_allowlist = [ - "quantile", - "fillna", - "mad", - "take", - "idxmax", - "idxmin", - "tshift", - "skew", - "plot", - "hist", - "dtype", - "corr", - "cov", - "diff", - "unique", - "nlargest", - "nsmallest", - "is_monotonic_increasing", - "is_monotonic_decreasing", -] - - -@pytest.fixture(params=s_allowlist) -def s_allowlist_fixture(request): - return request.param - @pytest.fixture def df(): @@ -113,54 +62,6 @@ def df_letters(): return df -@pytest.mark.parametrize("allowlist", [df_allowlist, s_allowlist]) -def test_groupby_allowlist(df_letters, allowlist): - df = df_letters - if allowlist == df_allowlist: - # dataframe - obj = df_letters - else: - obj = df_letters["floats"] - - gb = obj.groupby(df.letters) - - assert set(allowlist) == set(gb._apply_allowlist) - - -def check_allowlist(obj, df, m): - # check the obj for a particular allowlist m - - gb = obj.groupby(df.letters) - - f = getattr(type(gb), m) - - # name - try: - n = f.__name__ - except AttributeError: - return - assert n == m - - # qualname - try: - n = f.__qualname__ - except AttributeError: - return - assert n.endswith(m) - - -def test_groupby_series_allowlist(df_letters, s_allowlist_fixture): - m = s_allowlist_fixture - df = df_letters - check_allowlist(df.letters, df, m) - - -def test_groupby_frame_allowlist(df_letters, df_allowlist_fixture): - m = df_allowlist_fixture - df = df_letters - check_allowlist(df, df, m) - - @pytest.fixture def raw_frame(multiindex_dataframe_random_data): df = multiindex_dataframe_random_data diff --git a/pandas/tests/groupby/test_api_consistency.py b/pandas/tests/groupby/test_api_consistency.py new file mode 100644 index 0000000000000..1e82c2b6ac6e2 --- /dev/null +++ b/pandas/tests/groupby/test_api_consistency.py @@ -0,0 +1,136 @@ +""" +Test the consistency of the groupby API, both internally and with other pandas objects. +""" + +import inspect + +import pytest + +from pandas import ( + DataFrame, + Series, +) +from pandas.core.groupby.generic import ( + DataFrameGroupBy, + SeriesGroupBy, +) + + +def test_frame_consistency(request, groupby_func): + # GH#48028 + if groupby_func in ("first", "last"): + msg = "first and last are entirely different between frame and groupby" + request.node.add_marker(pytest.mark.xfail(reason=msg)) + if groupby_func in ("nth", "cumcount", "ngroup"): + msg = "DataFrame has no such method" + request.node.add_marker(pytest.mark.xfail(reason=msg)) + if groupby_func in ("size",): + msg = "Method is a property" + request.node.add_marker(pytest.mark.xfail(reason=msg)) + + frame_method = getattr(DataFrame, groupby_func) + gb_method = getattr(DataFrameGroupBy, groupby_func) + result = set(inspect.signature(gb_method).parameters) + expected = set(inspect.signature(frame_method).parameters) + + # Exclude certain arguments from result and expected depending on the operation + # Some of these may be purposeful inconsistencies between the APIs + exclude_expected, exclude_result = set(), set() + if groupby_func in ("any", "all"): + exclude_expected = {"kwargs", "bool_only", "level", "axis"} + elif groupby_func in ("count",): + exclude_expected = {"numeric_only", "level", "axis"} + elif groupby_func in ("nunique",): + exclude_expected = {"axis"} + elif groupby_func in ("max", "min"): + exclude_expected = {"axis", "kwargs", "level", "skipna"} + exclude_result = {"min_count", "engine", "engine_kwargs"} + elif groupby_func in ("mean", "std", "sum", "var"): + exclude_expected = {"axis", "kwargs", "level", "skipna"} + exclude_result = {"engine", "engine_kwargs"} + elif groupby_func in ("median", "prod", "sem"): + exclude_expected = {"axis", "kwargs", "level", "skipna"} + elif groupby_func in ("backfill", "bfill", "ffill", "pad"): + exclude_expected = {"downcast", "inplace", "axis"} + elif groupby_func in ("cummax", "cummin"): + exclude_expected = {"skipna", "args"} + exclude_result = {"numeric_only"} + elif groupby_func in ("cumprod", "cumsum"): + exclude_expected = {"skipna"} + elif groupby_func in ("pct_change",): + exclude_expected = {"kwargs"} + exclude_result = {"axis"} + elif groupby_func in ("rank",): + exclude_expected = {"numeric_only"} + elif groupby_func in ("quantile",): + exclude_expected = {"method", "axis"} + + # Ensure excluded arguments are actually in the signatures + assert result & exclude_result == exclude_result + assert expected & exclude_expected == exclude_expected + + result -= exclude_result + expected -= exclude_expected + assert result == expected + + +def test_series_consistency(request, groupby_func): + # GH#48028 + if groupby_func in ("first", "last"): + msg = "first and last are entirely different between Series and groupby" + request.node.add_marker(pytest.mark.xfail(reason=msg)) + if groupby_func in ("nth", "cumcount", "ngroup", "corrwith"): + msg = "Series has no such method" + request.node.add_marker(pytest.mark.xfail(reason=msg)) + if groupby_func in ("size",): + msg = "Method is a property" + request.node.add_marker(pytest.mark.xfail(reason=msg)) + + series_method = getattr(Series, groupby_func) + gb_method = getattr(SeriesGroupBy, groupby_func) + result = set(inspect.signature(gb_method).parameters) + expected = set(inspect.signature(series_method).parameters) + + # Exclude certain arguments from result and expected depending on the operation + # Some of these may be purposeful inconsistencies between the APIs + exclude_expected, exclude_result = set(), set() + if groupby_func in ("any", "all"): + exclude_expected = {"kwargs", "bool_only", "level", "axis"} + elif groupby_func in ("count",): + exclude_expected = {"level"} + elif groupby_func in ("tshift",): + exclude_expected = {"axis"} + elif groupby_func in ("diff",): + exclude_result = {"axis"} + elif groupby_func in ("max", "min"): + exclude_expected = {"axis", "kwargs", "level", "skipna"} + exclude_result = {"min_count", "engine", "engine_kwargs"} + elif groupby_func in ("mean", "std", "sum", "var"): + exclude_expected = {"axis", "kwargs", "level", "skipna"} + exclude_result = {"engine", "engine_kwargs"} + elif groupby_func in ("median", "prod", "sem"): + exclude_expected = {"axis", "kwargs", "level", "skipna"} + elif groupby_func in ("backfill", "bfill", "ffill", "pad"): + exclude_expected = {"downcast", "inplace", "axis"} + elif groupby_func in ("cummax", "cummin"): + exclude_expected = {"skipna", "args"} + exclude_result = {"numeric_only"} + elif groupby_func in ("cumprod", "cumsum"): + exclude_expected = {"skipna"} + elif groupby_func in ("pct_change",): + exclude_expected = {"kwargs"} + exclude_result = {"axis"} + elif groupby_func in ("rank",): + exclude_expected = {"numeric_only"} + elif groupby_func in ("idxmin", "idxmax"): + exclude_expected = {"args", "kwargs"} + elif groupby_func in ("quantile",): + exclude_result = {"numeric_only"} + + # Ensure excluded arguments are actually in the signatures + assert result & exclude_result == exclude_result + assert expected & exclude_expected == exclude_expected + + result -= exclude_result + expected -= exclude_expected + assert result == expected diff --git a/pandas/util/_decorators.py b/pandas/util/_decorators.py index 86c945f1321f5..5a9a109d43bf4 100644 --- a/pandas/util/_decorators.py +++ b/pandas/util/_decorators.py @@ -351,7 +351,7 @@ def wrapper(*args, **kwargs) -> Callable[..., Any]: return decorate -def doc(*docstrings: str | Callable, **params) -> Callable[[F], F]: +def doc(*docstrings: None | str | Callable, **params) -> Callable[[F], F]: """ A decorator take docstring templates, concatenate them and perform string substitution on it. @@ -364,7 +364,7 @@ def doc(*docstrings: str | Callable, **params) -> Callable[[F], F]: Parameters ---------- - *docstrings : str or callable + *docstrings : None, str, or callable The string / docstring / docstring template to be appended in order after default docstring under callable. **params @@ -378,6 +378,8 @@ def decorator(decorated: F) -> F: docstring_components.append(dedent(decorated.__doc__)) for docstring in docstrings: + if docstring is None: + continue if hasattr(docstring, "_docstring_components"): # error: Item "str" of "Union[str, Callable[..., Any]]" has no attribute # "_docstring_components" @@ -389,13 +391,19 @@ def decorator(decorated: F) -> F: elif isinstance(docstring, str) or docstring.__doc__: docstring_components.append(docstring) - # formatting templates and concatenating docstring + params_applied = [ + component.format(**params) + if isinstance(component, str) and len(params) > 0 + else component + for component in docstring_components + ] + decorated.__doc__ = "".join( [ - component.format(**params) + component if isinstance(component, str) else dedent(component.__doc__ or "") - for component in docstring_components + for component in params_applied ] )