-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
CLN: type annotations in groupby.grouper, groupby.ops #29456
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
98b53d7
efd4a9b
1933277
9b6a87a
d52add4
dee81f6
a7e6ad1
59cdf0a
6966fba
f038302
dc250f1
0b28143
1dfd414
6d3d485
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
) | ||
from pandas.core.dtypes.missing import _maybe_fill, isna | ||
|
||
from pandas._typing import FrameOrSeries | ||
import pandas.core.algorithms as algorithms | ||
from pandas.core.base import SelectionMixin | ||
import pandas.core.common as com | ||
|
@@ -89,12 +90,16 @@ def __init__( | |
|
||
self._filter_empty_groups = self.compressed = len(groupings) != 1 | ||
self.axis = axis | ||
self.groupings = groupings # type: Sequence[grouper.Grouping] | ||
self._groupings = list(groupings) # type: List[grouper.Grouping] | ||
self.sort = sort | ||
self.group_keys = group_keys | ||
self.mutated = mutated | ||
self.indexer = indexer | ||
|
||
@property | ||
def groupings(self) -> List["grouper.Grouping"]: | ||
return self._groupings | ||
|
||
@property | ||
def shape(self): | ||
return tuple(ping.ngroups for ping in self.groupings) | ||
|
@@ -106,7 +111,7 @@ def __iter__(self): | |
def nkeys(self) -> int: | ||
return len(self.groupings) | ||
|
||
def get_iterator(self, data, axis=0): | ||
def get_iterator(self, data: FrameOrSeries, axis: int = 0): | ||
""" | ||
Groupby iterator | ||
|
||
|
@@ -120,7 +125,7 @@ def get_iterator(self, data, axis=0): | |
for key, (i, group) in zip(keys, splitter): | ||
yield key, group | ||
|
||
def _get_splitter(self, data, axis=0): | ||
def _get_splitter(self, data: FrameOrSeries, axis: int = 0) -> "DataSplitter": | ||
comp_ids, _, ngroups = self.group_info | ||
return get_splitter(data, comp_ids, ngroups, axis=axis) | ||
|
||
|
@@ -142,13 +147,13 @@ def _get_group_keys(self): | |
# provide "flattened" iterator for multi-group setting | ||
return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes) | ||
|
||
def apply(self, f, data, axis: int = 0): | ||
def apply(self, f, data: FrameOrSeries, axis: int = 0): | ||
mutated = self.mutated | ||
splitter = self._get_splitter(data, axis=axis) | ||
group_keys = self._get_group_keys() | ||
result_values = None | ||
|
||
sdata = splitter._get_sorted_data() | ||
sdata = splitter._get_sorted_data() # type: FrameOrSeries | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this needed? shouldn't need to add a type annotation here. maybe the return type of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. _get_sorted_data return type is annotated, but mypy complains without this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can update to py3.6 syntax in a followon There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no longer needed after e6c5f5a |
||
if sdata.ndim == 2 and np.any(sdata.dtypes.apply(is_extension_array_dtype)): | ||
# calling splitter.fast_apply will raise TypeError via apply_frame_axis0 | ||
# if we pass EA instead of ndarray | ||
|
@@ -157,7 +162,7 @@ def apply(self, f, data, axis: int = 0): | |
|
||
elif ( | ||
com.get_callable_name(f) not in base.plotting_methods | ||
and hasattr(splitter, "fast_apply") | ||
and isinstance(splitter, FrameSplitter) | ||
and axis == 0 | ||
# with MultiIndex, apply_frame_axis0 would raise InvalidApply | ||
# TODO: can we make this check prettier? | ||
|
@@ -229,8 +234,7 @@ def names(self): | |
|
||
def size(self) -> Series: | ||
""" | ||
Compute group sizes | ||
|
||
Compute group sizes. | ||
""" | ||
ids, _, ngroup = self.group_info | ||
ids = ensure_platform_int(ids) | ||
|
@@ -292,7 +296,7 @@ def reconstructed_codes(self) -> List[np.ndarray]: | |
return decons_obs_group_ids(comp_ids, obs_ids, self.shape, codes, xnull=True) | ||
|
||
@cache_readonly | ||
def result_index(self): | ||
def result_index(self) -> Index: | ||
if not self.compressed and len(self.groupings) == 1: | ||
return self.groupings[0].result_index.rename(self.names[0]) | ||
|
||
|
@@ -628,7 +632,7 @@ def agg_series(self, obj: Series, func): | |
raise | ||
return self._aggregate_series_pure_python(obj, func) | ||
|
||
def _aggregate_series_fast(self, obj, func): | ||
def _aggregate_series_fast(self, obj: Series, func): | ||
# At this point we have already checked that | ||
# - obj.index is not a MultiIndex | ||
# - obj is backed by an ndarray, not ExtensionArray | ||
|
@@ -646,7 +650,7 @@ def _aggregate_series_fast(self, obj, func): | |
result, counts = grouper.get_result() | ||
return result, counts | ||
|
||
def _aggregate_series_pure_python(self, obj, func): | ||
def _aggregate_series_pure_python(self, obj: Series, func): | ||
|
||
group_index, _, ngroups = self.group_info | ||
|
||
|
@@ -703,7 +707,12 @@ class BinGrouper(BaseGrouper): | |
""" | ||
|
||
def __init__( | ||
self, bins, binlabels, filter_empty=False, mutated=False, indexer=None | ||
self, | ||
bins, | ||
binlabels, | ||
filter_empty: bool = False, | ||
mutated: bool = False, | ||
indexer=None, | ||
): | ||
self.bins = ensure_int64(bins) | ||
self.binlabels = ensure_index(binlabels) | ||
|
@@ -737,7 +746,7 @@ def _get_grouper(self): | |
""" | ||
return self | ||
|
||
def get_iterator(self, data: NDFrame, axis: int = 0): | ||
def get_iterator(self, data: FrameOrSeries, axis: int = 0): | ||
""" | ||
Groupby iterator | ||
|
||
|
@@ -809,11 +818,9 @@ def names(self): | |
return [self.binlabels.name] | ||
|
||
@property | ||
def groupings(self): | ||
from pandas.core.groupby.grouper import Grouping | ||
|
||
def groupings(self) -> "List[grouper.Grouping]": | ||
return [ | ||
Grouping(lvl, lvl, in_axis=False, level=None, name=name) | ||
grouper.Grouping(lvl, lvl, in_axis=False, level=None, name=name) | ||
for lvl, name in zip(self.levels, self.names) | ||
] | ||
|
||
|
@@ -854,7 +861,7 @@ def _is_indexed_like(obj, axes) -> bool: | |
|
||
|
||
class DataSplitter: | ||
def __init__(self, data, labels, ngroups, axis: int = 0): | ||
def __init__(self, data: FrameOrSeries, labels, ngroups: int, axis: int = 0): | ||
self.data = data | ||
self.labels = ensure_int64(labels) | ||
self.ngroups = ngroups | ||
|
@@ -885,15 +892,15 @@ def __iter__(self): | |
for i, (start, end) in enumerate(zip(starts, ends)): | ||
yield i, self._chop(sdata, slice(start, end)) | ||
|
||
def _get_sorted_data(self): | ||
def _get_sorted_data(self) -> FrameOrSeries: | ||
return self.data.take(self.sort_idx, axis=self.axis) | ||
|
||
def _chop(self, sdata, slice_obj: slice): | ||
def _chop(self, sdata, slice_obj: slice) -> NDFrame: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is NDFrame used? is _chop not generic? should DataSplitter be a generic class? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I dont understand the question. Is "generic class" meaningfully different from "base class"? NDFrame is used because one subclass returns Series and the other returns DataFrame There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so looking at the definition of _chop in the derived classes, i'm guessing this abstractmethod should be typed as
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using FrameOrSeries here produces complaints:
I'm getting close to saying "screw it" when dealing with this type of error. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mypy won't be looking at the derived classes when it performs type checking. it'll be looking at the type hints on the base class when it checks other methods in the base class. the abstractmethod should be generic since that is how the derived classes are typed Series -> Series and DataFrame -> DataFrame. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we are mixing a few different paradigms here. The subclasses should probably be annotated with the type respective to the class, rather than using the TypeVar, i.e. you would never parametrize a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
correct. but if a method of the base class is not overridden then the Series type in the derived class will become an NDFrame type after calling that method in the base class. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There are some annotations in this PR that make it easier to reason about this code while reading it. The annotations in this sub-thread are not among them, so I do not particularly care about them. Let's focus for now on a minimal change needed to get this merged, as there are more bugfix PRs waiting in the wings. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. e6c5f5a fixes this. |
||
raise AbstractMethodError(self) | ||
|
||
|
||
class SeriesSplitter(DataSplitter): | ||
def _chop(self, sdata, slice_obj: slice): | ||
def _chop(self, sdata: Series, slice_obj: slice) -> Series: | ||
return sdata._get_values(slice_obj) | ||
|
||
|
||
|
@@ -905,14 +912,14 @@ def fast_apply(self, f, names): | |
sdata = self._get_sorted_data() | ||
return libreduction.apply_frame_axis0(sdata, f, names, starts, ends) | ||
|
||
def _chop(self, sdata, slice_obj: slice): | ||
def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame: | ||
simonjayhawkins marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if self.axis == 0: | ||
return sdata.iloc[slice_obj] | ||
else: | ||
return sdata._slice(slice_obj, axis=1) | ||
|
||
|
||
def get_splitter(data: NDFrame, *args, **kwargs): | ||
def get_splitter(data: FrameOrSeries, *args, **kwargs) -> DataSplitter: | ||
if isinstance(data, Series): | ||
klass = SeriesSplitter # type: Type[DataSplitter] | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is for my own benefit in trying to reason about this code.