Skip to content

CLN: type annotations in groupby.grouper, groupby.ops #29456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Nov 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions pandas/core/groupby/grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(self, key=None, level=None, freq=None, axis=0, sort=False):
def ax(self):
return self.grouper

def _get_grouper(self, obj, validate=True):
def _get_grouper(self, obj, validate: bool = True):
"""
Parameters
----------
Expand All @@ -143,17 +143,18 @@ def _get_grouper(self, obj, validate=True):
)
return self.binner, self.grouper, self.obj

def _set_grouper(self, obj, sort=False):
def _set_grouper(self, obj: FrameOrSeries, sort: bool = False):
"""
given an object and the specifications, setup the internal grouper
for this particular specification

Parameters
----------
obj : the subject object
obj : Series or DataFrame
sort : bool, default False
whether the resulting grouper should be sorted
"""
assert obj is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for my own benefit in trying to reason about this code.


if self.key is not None and self.level is not None:
raise ValueError("The Grouper cannot specify both a key and a level!")
Expand Down Expand Up @@ -211,13 +212,13 @@ def groups(self):

def __repr__(self) -> str:
attrs_list = (
"{}={!r}".format(attr_name, getattr(self, attr_name))
"{name}={val!r}".format(name=attr_name, val=getattr(self, attr_name))
for attr_name in self._attributes
if getattr(self, attr_name) is not None
)
attrs = ", ".join(attrs_list)
cls_name = self.__class__.__name__
return "{}({})".format(cls_name, attrs)
return "{cls}({attrs})".format(cls=cls_name, attrs=attrs)


class Grouping:
Expand Down Expand Up @@ -372,7 +373,7 @@ def __init__(
self.grouper = self.grouper.astype("timedelta64[ns]")

def __repr__(self) -> str:
return "Grouping({0})".format(self.name)
return "Grouping({name})".format(name=self.name)

def __iter__(self):
    # Iterating a Grouping yields the keys of ``self.indices``
    # (presumably the group labels -- confirm against the definition
    # of ``indices``, which is not visible here).
    return iter(self.indices)
Expand Down Expand Up @@ -433,10 +434,10 @@ def get_grouper(
key=None,
axis: int = 0,
level=None,
sort=True,
observed=False,
mutated=False,
validate=True,
sort: bool = True,
observed: bool = False,
mutated: bool = False,
validate: bool = True,
) -> Tuple[BaseGrouper, List[Hashable], FrameOrSeries]:
"""
Create and return a BaseGrouper, which is an internal
Expand Down Expand Up @@ -670,7 +671,7 @@ def is_in_obj(gpr) -> bool:
return grouper, exclusions, obj


def _is_label_like(val):
def _is_label_like(val) -> bool:
return isinstance(val, (str, tuple)) or (val is not None and is_scalar(val))


Expand Down
53 changes: 30 additions & 23 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from pandas.core.dtypes.missing import _maybe_fill, isna

from pandas._typing import FrameOrSeries
import pandas.core.algorithms as algorithms
from pandas.core.base import SelectionMixin
import pandas.core.common as com
Expand Down Expand Up @@ -89,12 +90,16 @@ def __init__(

self._filter_empty_groups = self.compressed = len(groupings) != 1
self.axis = axis
self.groupings = groupings # type: Sequence[grouper.Grouping]
self._groupings = list(groupings) # type: List[grouper.Grouping]
self.sort = sort
self.group_keys = group_keys
self.mutated = mutated
self.indexer = indexer

@property
def groupings(self) -> List["grouper.Grouping"]:
    """The Grouping objects backing this grouper, as a list.

    Exposed via a read-only property so callers cannot rebind it.
    """
    return self._groupings

@property
def shape(self):
return tuple(ping.ngroups for ping in self.groupings)
Expand All @@ -106,7 +111,7 @@ def __iter__(self):
def nkeys(self) -> int:
return len(self.groupings)

def get_iterator(self, data, axis=0):
def get_iterator(self, data: FrameOrSeries, axis: int = 0):
"""
Groupby iterator

Expand All @@ -120,7 +125,7 @@ def get_iterator(self, data, axis=0):
for key, (i, group) in zip(keys, splitter):
yield key, group

def _get_splitter(self, data: FrameOrSeries, axis: int = 0) -> "DataSplitter":
    """
    Return a DataSplitter for ``data`` (Series- or Frame-specific,
    chosen by ``get_splitter``).

    NOTE(review): ``comp_ids``/``ngroups`` come from ``self.group_info``;
    presumably comp_ids assigns each row its group number -- confirm
    against the ``group_info`` definition.
    """
    comp_ids, _, ngroups = self.group_info
    return get_splitter(data, comp_ids, ngroups, axis=axis)

Expand All @@ -142,13 +147,13 @@ def _get_group_keys(self):
# provide "flattened" iterator for multi-group setting
return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes)

def apply(self, f, data, axis: int = 0):
def apply(self, f, data: FrameOrSeries, axis: int = 0):
mutated = self.mutated
splitter = self._get_splitter(data, axis=axis)
group_keys = self._get_group_keys()
result_values = None

sdata = splitter._get_sorted_data()
sdata = splitter._get_sorted_data() # type: FrameOrSeries
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed?

shouldn't need to add a type annotation here. maybe the return type of _get_sorted_data needs to be added.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_get_sorted_data return type is annotated, but mypy complains without this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can update to py3.6 syntax in a followon

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no longer needed after e6c5f5a

if sdata.ndim == 2 and np.any(sdata.dtypes.apply(is_extension_array_dtype)):
# calling splitter.fast_apply will raise TypeError via apply_frame_axis0
# if we pass EA instead of ndarray
Expand All @@ -157,7 +162,7 @@ def apply(self, f, data, axis: int = 0):

elif (
com.get_callable_name(f) not in base.plotting_methods
and hasattr(splitter, "fast_apply")
and isinstance(splitter, FrameSplitter)
and axis == 0
# with MultiIndex, apply_frame_axis0 would raise InvalidApply
# TODO: can we make this check prettier?
Expand Down Expand Up @@ -229,8 +234,7 @@ def names(self):

def size(self) -> Series:
"""
Compute group sizes

Compute group sizes.
"""
ids, _, ngroup = self.group_info
ids = ensure_platform_int(ids)
Expand Down Expand Up @@ -292,7 +296,7 @@ def reconstructed_codes(self) -> List[np.ndarray]:
return decons_obs_group_ids(comp_ids, obs_ids, self.shape, codes, xnull=True)

@cache_readonly
def result_index(self):
def result_index(self) -> Index:
if not self.compressed and len(self.groupings) == 1:
return self.groupings[0].result_index.rename(self.names[0])

Expand Down Expand Up @@ -628,7 +632,7 @@ def agg_series(self, obj: Series, func):
raise
return self._aggregate_series_pure_python(obj, func)

def _aggregate_series_fast(self, obj, func):
def _aggregate_series_fast(self, obj: Series, func):
# At this point we have already checked that
# - obj.index is not a MultiIndex
# - obj is backed by an ndarray, not ExtensionArray
Expand All @@ -646,7 +650,7 @@ def _aggregate_series_fast(self, obj, func):
result, counts = grouper.get_result()
return result, counts

def _aggregate_series_pure_python(self, obj, func):
def _aggregate_series_pure_python(self, obj: Series, func):

group_index, _, ngroups = self.group_info

Expand Down Expand Up @@ -703,7 +707,12 @@ class BinGrouper(BaseGrouper):
"""

def __init__(
self, bins, binlabels, filter_empty=False, mutated=False, indexer=None
self,
bins,
binlabels,
filter_empty: bool = False,
mutated: bool = False,
indexer=None,
):
self.bins = ensure_int64(bins)
self.binlabels = ensure_index(binlabels)
Expand Down Expand Up @@ -737,7 +746,7 @@ def _get_grouper(self):
"""
return self

def get_iterator(self, data: NDFrame, axis: int = 0):
def get_iterator(self, data: FrameOrSeries, axis: int = 0):
"""
Groupby iterator

Expand Down Expand Up @@ -809,11 +818,9 @@ def names(self):
return [self.binlabels.name]

@property
def groupings(self) -> "List[grouper.Grouping]":
    """Build one Grouping per bin, pairing each level with its name."""
    out = []
    for lvl, name in zip(self.levels, self.names):
        out.append(
            grouper.Grouping(lvl, lvl, in_axis=False, level=None, name=name)
        )
    return out

Expand Down Expand Up @@ -854,7 +861,7 @@ def _is_indexed_like(obj, axes) -> bool:


class DataSplitter:
def __init__(self, data, labels, ngroups, axis: int = 0):
def __init__(self, data: FrameOrSeries, labels, ngroups: int, axis: int = 0):
self.data = data
self.labels = ensure_int64(labels)
self.ngroups = ngroups
Expand Down Expand Up @@ -885,15 +892,15 @@ def __iter__(self):
for i, (start, end) in enumerate(zip(starts, ends)):
yield i, self._chop(sdata, slice(start, end))

def _get_sorted_data(self):
def _get_sorted_data(self) -> FrameOrSeries:
return self.data.take(self.sort_idx, axis=self.axis)

def _chop(self, sdata, slice_obj: slice):
def _chop(self, sdata, slice_obj: slice) -> NDFrame:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is NDFrame used? is _chop not generic? should DataSplitter be a generic class?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont understand the question. Is "generic class" meaningfully different from "base class"? NDFrame is used because one subclass returns Series and the other returns DataFrame

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DataSplitter.__init__ accepts FrameOrSeries. do we need to persist this type thoughout the class. i.e. make DataSplitter a generic class. see https://mypy.readthedocs.io/en/latest/generics.html#defining-generic-classes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so looking at the definition of _chop in the derived classes, i'm guessing this abstractmethod should be typed as

 def _chop(self, sdata: FrameOrSeries, slice_obj: slice) -> FrameOrSeries:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using FrameOrSeries here produces complaints:

pandas/core/groupby/ops.py:879: error: Argument 1 of "_chop" is incompatible with supertype "DataSplitter"; supertype defines the argument type as "FrameOrSeries"
pandas/core/groupby/ops.py:879: error: Return type "Series" of "_chop" incompatible with return type "FrameOrSeries" in supertype "DataSplitter"
pandas/core/groupby/ops.py:891: error: Argument 1 of "_chop" is incompatible with supertype "DataSplitter"; supertype defines the argument type as "FrameOrSeries"
pandas/core/groupby/ops.py:891: error: Return type "DataFrame" of "_chop" incompatible with return type "FrameOrSeries" in supertype "DataSplitter"

I'm getting close to saying "screw it" when dealing with this type of error.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy won't be looking at the derived classes when it performs type checking. it'll be looking at the type hints on the base class when it checks other methods in the base class.

the abstractmethod should be generic since that is how the derived classes are typed Series -> Series and DataFrame -> DataFrame.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we are mixing a few different paradigms here. The subclasses should probably be annotated with the type respective to the class, rather than using the TypeVar, i.e. you would never parametrize a SeriesSplitter with a DataFrame - it exclusively deals with Series objects

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you would never parametrize a SeriesSplitter with a DataFrame - it exclusively deals with Series objects

correct. but if a method of the base class is not overridden then the Series type in the derived class will become an NDFrame type after calling that method in the base class.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we are mixing a few different paradigms here.

There are some annotations in this PR that make it easier to reason about this code while reading it. The annotations in this sub-thread are not among them, so I do not particularly care about them. Let's focus for now on a minimal change needed to get this merged, as there are more bugfix PRs waiting in the wings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e6c5f5a fixes this.

raise AbstractMethodError(self)


class SeriesSplitter(DataSplitter):
    """DataSplitter specialization whose chunks are Series."""

    def _chop(self, sdata: Series, slice_obj: slice) -> Series:
        # Positional slice of the sorted Series. NOTE(review):
        # ``_get_values`` is pandas-internal -- assumed equivalent to
        # ``sdata.iloc[slice_obj]`` here; confirm before relying on it.
        return sdata._get_values(slice_obj)


Expand All @@ -905,14 +912,14 @@ def fast_apply(self, f, names):
sdata = self._get_sorted_data()
return libreduction.apply_frame_axis0(sdata, f, names, starts, ends)

def _chop(self, sdata, slice_obj: slice):
def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
if self.axis == 0:
return sdata.iloc[slice_obj]
else:
return sdata._slice(slice_obj, axis=1)


def get_splitter(data: NDFrame, *args, **kwargs):
def get_splitter(data: FrameOrSeries, *args, **kwargs) -> DataSplitter:
if isinstance(data, Series):
klass = SeriesSplitter # type: Type[DataSplitter]
else:
Expand Down