From 2ec81e20fc07c4fc2294950728e93de179c17d53 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 3 May 2021 18:36:16 -0700 Subject: [PATCH 1/2] CLN: more descriptive names, annotations in groupby --- pandas/core/groupby/generic.py | 15 +++++++------ pandas/core/groupby/groupby.py | 8 +++++-- pandas/core/groupby/grouper.py | 41 +++++++++++++++++++++++++--------- pandas/core/groupby/ops.py | 2 ++ 4 files changed, 47 insertions(+), 19 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 4f60660dfb499..f9d4efbd1b8e8 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -524,7 +524,8 @@ def _aggregate_named(self, func, *args, **kwargs): for name, group in self: # Each step of this loop corresponds to # libreduction._BaseGrouper._apply_to_group - group.name = name # NB: libreduction does not pin name + # NB: libreduction does not pin name + object.__setattr__(group, "name", name) output = func(group, *args, **kwargs) output = libreduction.extract_result(output) @@ -567,9 +568,9 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): # Temporarily set observed for dealing with categoricals. with com.temp_setattr(self, "observed", True): result = getattr(self, func)(*args, **kwargs) - return self._transform_fast(result) + return self._wrap_transform_fast_result(result) - def _transform_general(self, func, *args, **kwargs): + def _transform_general(self, func: Callable, *args, **kwargs) -> Series: """ Transform with a callable func`. """ @@ -599,7 +600,7 @@ def _transform_general(self, func, *args, **kwargs): result.name = self._selected_obj.name return result - def _transform_fast(self, result) -> Series: + def _wrap_transform_fast_result(self, result: Series) -> Series: """ fast version of transform, only applicable to builtin/cythonizable functions @@ -1436,11 +1437,11 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): if isinstance(result, DataFrame) and result.columns.equals( self._obj_with_exclusions.columns ): - return self._transform_fast(result) + return self._wrap_transform_fast_result(result) return self._transform_general(func, *args, **kwargs) - def _transform_fast(self, result: DataFrame) -> DataFrame: + def _wrap_transform_fast_result(self, result: DataFrame) -> DataFrame: """ Fast transform path for aggregations """ @@ -1653,7 +1654,7 @@ def _gotitem(self, key, ndim: int, subset=None): raise AssertionError("invalid ndim for _gotitem") - def _wrap_frame_output(self, result, obj: DataFrame) -> DataFrame: + def _wrap_frame_output(self, result: dict, obj: DataFrame) -> DataFrame: result_index = self.grouper.levels[0] if self.axis == 0: diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 620668dadc32d..4bcb1b5a19cb6 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -18,6 +18,7 @@ class providing the base-class of operations. from textwrap import dedent import types from typing import ( + TYPE_CHECKING, Callable, Generic, Hashable, @@ -104,6 +105,9 @@ class providing the base-class of operations. from pandas.core.sorting import get_group_index_sorter from pandas.core.util.numba_ import NUMBA_FUNC_CACHE +if TYPE_CHECKING: + from typing import Literal + _common_see_also = """ See Also -------- @@ -1989,7 +1993,7 @@ def ewm(self, *args, **kwargs): ) @final - def _fill(self, direction, limit=None): + def _fill(self, direction: Literal["ffill", "bfill"], limit=None): """ Shared function for `pad` and `backfill` to call Cython method. @@ -2731,7 +2735,7 @@ def _get_cythonized_result( name = obj.name values = obj._values - if numeric_only and not is_numeric_dtype(values): + if numeric_only and not is_numeric_dtype(values.dtype): continue if aggregate: diff --git a/pandas/core/groupby/grouper.py b/pandas/core/groupby/grouper.py index 151756b829a1d..f1762a2535ff7 100644 --- a/pandas/core/groupby/grouper.py +++ b/pandas/core/groupby/grouper.py @@ -249,6 +249,10 @@ class Grouper: Freq: 17T, dtype: int64 """ + axis: int + sort: bool + dropna: bool + _attributes: tuple[str, ...] = ("key", "level", "freq", "axis", "sort") def __new__(cls, *args, **kwargs): @@ -260,7 +264,13 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls) def __init__( - self, key=None, level=None, freq=None, axis=0, sort=False, dropna=True + self, + key=None, + level=None, + freq=None, + axis: int = 0, + sort: bool = False, + dropna: bool = True, ): self.key = key self.level = level @@ -281,11 +291,11 @@ def __init__( def ax(self): return self.grouper - def _get_grouper(self, obj, validate: bool = True): + def _get_grouper(self, obj: FrameOrSeries, validate: bool = True): """ Parameters ---------- - obj : the subject object + obj : Series or DataFrame validate : bool, default True if True, validate the grouper @@ -296,7 +306,9 @@ def _get_grouper(self, obj, validate: bool = True): self._set_grouper(obj) # error: Value of type variable "FrameOrSeries" of "get_grouper" cannot be # "Optional[Any]" - self.grouper, _, self.obj = get_grouper( # type: ignore[type-var] + # error: Incompatible types in assignment (expression has type "BaseGrouper", + # variable has type "None") + self.grouper, _, self.obj = get_grouper( # type: ignore[type-var,assignment] self.obj, [self.key], axis=self.axis, @@ -375,15 +387,19 @@ def _set_grouper(self, obj: FrameOrSeries, sort: bool = False): ax = ax.take(indexer) obj = obj.take(indexer, axis=self.axis) - self.obj = obj - self.grouper = ax + # error: Incompatible types in assignment (expression has type + # "FrameOrSeries", variable has type "None") + self.obj = obj # type: ignore[assignment] + # error: Incompatible types in assignment (expression has type "Index", + # variable has type "None") + self.grouper = ax # type: ignore[assignment] return self.grouper @final @property def groups(self): - # error: Item "None" of "Optional[Any]" has no attribute "groups" - return self.grouper.groups # type: ignore[union-attr] + # error: "None" has no attribute "groups" + return self.grouper.groups # type: ignore[attr-defined] @final def __repr__(self) -> str: @@ -428,7 +444,7 @@ def __init__( index: Index, grouper=None, obj: FrameOrSeries | None = None, - name=None, + name: Hashable = None, level=None, sort: bool = True, observed: bool = False, @@ -478,7 +494,12 @@ def __init__( # what key/level refer to exactly, don't need to # check again as we have by this point converted these # to an actual value (rather than a pd.Grouper) - _, grouper, _ = self.grouper._get_grouper(self.obj, validate=False) + _, grouper, _ = self.grouper._get_grouper( + # error: Value of type variable "FrameOrSeries" of "_get_grouper" + # of "Grouper" cannot be "Optional[FrameOrSeries]" + self.obj, # type: ignore[type-var] + validate=False, + ) if self.name is None: self.name = grouper.result_index.name self.obj = self.grouper.obj diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 975a902f49db9..e90892138f15a 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -662,6 +662,8 @@ class BaseGrouper: """ + axis: Index + def __init__( self, axis: Index, From d87a6ecc0be4183fc73dc3434383c03f1a5e6194 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 3 May 2021 18:59:23 -0700 Subject: [PATCH 2/2] REF: simplify _wrap_transform_fast_result --- pandas/core/groupby/generic.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index f9d4efbd1b8e8..324aef3cd5435 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1450,14 +1450,9 @@ def _wrap_transform_fast_result(self, result: DataFrame) -> DataFrame: # for each col, reshape to size of original frame by take operation ids, _, _ = self.grouper.group_info result = result.reindex(self.grouper.result_index, copy=False) - output = [ - algorithms.take_nd(result.iloc[:, i].values, ids) - for i, _ in enumerate(result.columns) - ] - - return self.obj._constructor._from_arrays( - output, columns=result.columns, index=obj.index - ) + output = result.take(ids, axis=0) + output.index = obj.index + return output def _define_paths(self, func, *args, **kwargs): if isinstance(func, str):