Skip to content

Commit d76f269

Browse files
jbrockmendel authored and JulianWgs committed
CLN: more descriptive names, annotations in groupby (pandas-dev#41300)
1 parent ac13698 commit d76f269

File tree

4 files changed

+50
-27
lines changed

4 files changed

+50
-27
lines changed

pandas/core/groupby/generic.py

+11-15
Original file line number | Diff line number | Diff line change
@@ -524,7 +524,8 @@ def _aggregate_named(self, func, *args, **kwargs):
524524
for name, group in self:
525525
# Each step of this loop corresponds to
526526
# libreduction._BaseGrouper._apply_to_group
527-
group.name = name # NB: libreduction does not pin name
527+
# NB: libreduction does not pin name
528+
object.__setattr__(group, "name", name)
528529

529530
output = func(group, *args, **kwargs)
530531
output = libreduction.extract_result(output)
@@ -567,9 +568,9 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
567568
# Temporarily set observed for dealing with categoricals.
568569
with com.temp_setattr(self, "observed", True):
569570
result = getattr(self, func)(*args, **kwargs)
570-
return self._transform_fast(result)
571+
return self._wrap_transform_fast_result(result)
571572

572-
def _transform_general(self, func, *args, **kwargs):
573+
def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
573574
"""
574575
Transform with a callable `func`.
575576
"""
@@ -599,7 +600,7 @@ def _transform_general(self, func, *args, **kwargs):
599600
result.name = self._selected_obj.name
600601
return result
601602

602-
def _transform_fast(self, result) -> Series:
603+
def _wrap_transform_fast_result(self, result: Series) -> Series:
603604
"""
604605
fast version of transform, only applicable to
605606
builtin/cythonizable functions
@@ -1436,11 +1437,11 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
14361437
if isinstance(result, DataFrame) and result.columns.equals(
14371438
self._obj_with_exclusions.columns
14381439
):
1439-
return self._transform_fast(result)
1440+
return self._wrap_transform_fast_result(result)
14401441

14411442
return self._transform_general(func, *args, **kwargs)
14421443

1443-
def _transform_fast(self, result: DataFrame) -> DataFrame:
1444+
def _wrap_transform_fast_result(self, result: DataFrame) -> DataFrame:
14441445
"""
14451446
Fast transform path for aggregations
14461447
"""
@@ -1449,14 +1450,9 @@ def _transform_fast(self, result: DataFrame) -> DataFrame:
14491450
# for each col, reshape to size of original frame by take operation
14501451
ids, _, _ = self.grouper.group_info
14511452
result = result.reindex(self.grouper.result_index, copy=False)
1452-
output = [
1453-
algorithms.take_nd(result.iloc[:, i].values, ids)
1454-
for i, _ in enumerate(result.columns)
1455-
]
1456-
1457-
return self.obj._constructor._from_arrays(
1458-
output, columns=result.columns, index=obj.index
1459-
)
1453+
output = result.take(ids, axis=0)
1454+
output.index = obj.index
1455+
return output
14601456

14611457
def _define_paths(self, func, *args, **kwargs):
14621458
if isinstance(func, str):
@@ -1653,7 +1649,7 @@ def _gotitem(self, key, ndim: int, subset=None):
16531649

16541650
raise AssertionError("invalid ndim for _gotitem")
16551651

1656-
def _wrap_frame_output(self, result, obj: DataFrame) -> DataFrame:
1652+
def _wrap_frame_output(self, result: dict, obj: DataFrame) -> DataFrame:
16571653
result_index = self.grouper.levels[0]
16581654

16591655
if self.axis == 0:

pandas/core/groupby/groupby.py

+6-2
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ class providing the base-class of operations.
1818
from textwrap import dedent
1919
import types
2020
from typing import (
21+
TYPE_CHECKING,
2122
Callable,
2223
Generic,
2324
Hashable,
@@ -106,6 +107,9 @@ class providing the base-class of operations.
106107

107108
from pandas.io.formats.format import repr_html_groupby
108109

110+
if TYPE_CHECKING:
111+
from typing import Literal
112+
109113
_common_see_also = """
110114
See Also
111115
--------
@@ -1994,7 +1998,7 @@ def ewm(self, *args, **kwargs):
19941998
)
19951999

19962000
@final
1997-
def _fill(self, direction, limit=None):
2001+
def _fill(self, direction: Literal["ffill", "bfill"], limit=None):
19982002
"""
19992003
Shared function for `pad` and `backfill` to call Cython method.
20002004
@@ -2736,7 +2740,7 @@ def _get_cythonized_result(
27362740
name = obj.name
27372741
values = obj._values
27382742

2739-
if numeric_only and not is_numeric_dtype(values):
2743+
if numeric_only and not is_numeric_dtype(values.dtype):
27402744
continue
27412745

27422746
if aggregate:

pandas/core/groupby/grouper.py

+31-10
Original file line number | Diff line number | Diff line change
@@ -249,6 +249,10 @@ class Grouper:
249249
Freq: 17T, dtype: int64
250250
"""
251251

252+
axis: int
253+
sort: bool
254+
dropna: bool
255+
252256
_attributes: tuple[str, ...] = ("key", "level", "freq", "axis", "sort")
253257

254258
def __new__(cls, *args, **kwargs):
@@ -260,7 +264,13 @@ def __new__(cls, *args, **kwargs):
260264
return super().__new__(cls)
261265

262266
def __init__(
263-
self, key=None, level=None, freq=None, axis=0, sort=False, dropna=True
267+
self,
268+
key=None,
269+
level=None,
270+
freq=None,
271+
axis: int = 0,
272+
sort: bool = False,
273+
dropna: bool = True,
264274
):
265275
self.key = key
266276
self.level = level
@@ -281,11 +291,11 @@ def __init__(
281291
def ax(self):
282292
return self.grouper
283293

284-
def _get_grouper(self, obj, validate: bool = True):
294+
def _get_grouper(self, obj: FrameOrSeries, validate: bool = True):
285295
"""
286296
Parameters
287297
----------
288-
obj : the subject object
298+
obj : Series or DataFrame
289299
validate : bool, default True
290300
if True, validate the grouper
291301
@@ -296,7 +306,9 @@ def _get_grouper(self, obj, validate: bool = True):
296306
self._set_grouper(obj)
297307
# error: Value of type variable "FrameOrSeries" of "get_grouper" cannot be
298308
# "Optional[Any]"
299-
self.grouper, _, self.obj = get_grouper( # type: ignore[type-var]
309+
# error: Incompatible types in assignment (expression has type "BaseGrouper",
310+
# variable has type "None")
311+
self.grouper, _, self.obj = get_grouper( # type: ignore[type-var,assignment]
300312
self.obj,
301313
[self.key],
302314
axis=self.axis,
@@ -375,15 +387,19 @@ def _set_grouper(self, obj: FrameOrSeries, sort: bool = False):
375387
ax = ax.take(indexer)
376388
obj = obj.take(indexer, axis=self.axis)
377389

378-
self.obj = obj
379-
self.grouper = ax
390+
# error: Incompatible types in assignment (expression has type
391+
# "FrameOrSeries", variable has type "None")
392+
self.obj = obj # type: ignore[assignment]
393+
# error: Incompatible types in assignment (expression has type "Index",
394+
# variable has type "None")
395+
self.grouper = ax # type: ignore[assignment]
380396
return self.grouper
381397

382398
@final
383399
@property
384400
def groups(self):
385-
# error: Item "None" of "Optional[Any]" has no attribute "groups"
386-
return self.grouper.groups # type: ignore[union-attr]
401+
# error: "None" has no attribute "groups"
402+
return self.grouper.groups # type: ignore[attr-defined]
387403

388404
@final
389405
def __repr__(self) -> str:
@@ -428,7 +444,7 @@ def __init__(
428444
index: Index,
429445
grouper=None,
430446
obj: FrameOrSeries | None = None,
431-
name=None,
447+
name: Hashable = None,
432448
level=None,
433449
sort: bool = True,
434450
observed: bool = False,
@@ -478,7 +494,12 @@ def __init__(
478494
# what key/level refer to exactly, don't need to
479495
# check again as we have by this point converted these
480496
# to an actual value (rather than a pd.Grouper)
481-
_, grouper, _ = self.grouper._get_grouper(self.obj, validate=False)
497+
_, grouper, _ = self.grouper._get_grouper(
498+
# error: Value of type variable "FrameOrSeries" of "_get_grouper"
499+
# of "Grouper" cannot be "Optional[FrameOrSeries]"
500+
self.obj, # type: ignore[type-var]
501+
validate=False,
502+
)
482503
if self.name is None:
483504
self.name = grouper.result_index.name
484505
self.obj = self.grouper.obj

pandas/core/groupby/ops.py

+2
Original file line number | Diff line number | Diff line change
@@ -662,6 +662,8 @@ class BaseGrouper:
662662
663663
"""
664664

665+
axis: Index
666+
665667
def __init__(
666668
self,
667669
axis: Index,

0 commit comments

Comments
 (0)