From c347b48ffd697d9a8e2c4f3c4e5c12ab0c28eb00 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Fri, 8 Nov 2019 10:12:29 +0000 Subject: [PATCH] TYPING: change to FrameOrSeries Alias in pandas._typing --- pandas/_typing.py | 4 ++-- pandas/core/groupby/groupby.py | 23 ++++++++++++++++++----- pandas/core/groupby/grouper.py | 7 +++---- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/pandas/_typing.py b/pandas/_typing.py index 445eff9e19e47..df2d327af92a3 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -21,8 +21,8 @@ from pandas.core.arrays.base import ExtensionArray # noqa: F401 from pandas.core.dtypes.dtypes import ExtensionDtype # noqa: F401 from pandas.core.indexes.base import Index # noqa: F401 + from pandas.core.frame import DataFrame # noqa: F401 from pandas.core.series import Series # noqa: F401 - from pandas.core.generic import NDFrame # noqa: F401 AnyArrayLike = TypeVar("AnyArrayLike", "ExtensionArray", "Index", "Series", np.ndarray) @@ -31,7 +31,7 @@ Dtype = Union[str, np.dtype, "ExtensionDtype"] FilePathOrBuffer = Union[str, Path, IO[AnyStr]] -FrameOrSeries = TypeVar("FrameOrSeries", bound="NDFrame") +FrameOrSeries = TypeVar("FrameOrSeries", "DataFrame", "Series") Scalar = Union[str, int, float, bool] Axis = Union[str, int] Ordered = Optional[bool] diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index e73be29d5b104..8b20371de603b 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -14,7 +14,17 @@ class providing the base-class of operations. import inspect import re import types -from typing import FrozenSet, Hashable, Iterable, List, Optional, Tuple, Type, Union +from typing import ( + FrozenSet, + Generic, + Hashable, + Iterable, + List, + Optional, + Tuple, + Type, + Union, +) import numpy as np @@ -41,6 +51,7 @@ class providing the base-class of operations. ) from pandas.core.dtypes.missing import isna, notna +from pandas._typing import FrameOrSeries from pandas.core import nanops import pandas.core.algorithms as algorithms from pandas.core.arrays import Categorical, try_cast_to_ea @@ -336,13 +347,13 @@ def _group_selection_context(groupby): groupby._reset_group_selection() -class _GroupBy(PandasObject, SelectionMixin): +class _GroupBy(PandasObject, SelectionMixin, Generic[FrameOrSeries]): _group_selection = None _apply_whitelist = frozenset() # type: FrozenSet[str] def __init__( self, - obj: NDFrame, + obj: FrameOrSeries, keys=None, axis: int = 0, level=None, @@ -391,7 +402,7 @@ def __init__( mutated=self.mutated, ) - self.obj = obj + self.obj = obj # type: FrameOrSeries self.axis = obj._get_axis_number(axis) self.grouper = grouper self.exclusions = set(exclusions) if exclusions else set() @@ -1662,7 +1673,9 @@ def backfill(self, limit=None): @Substitution(name="groupby") @Substitution(see_also=_common_see_also) - def nth(self, n: Union[int, List[int]], dropna: Optional[str] = None) -> DataFrame: + def nth( + self, n: Union[int, List[int]], dropna: Optional[str] = None + ) -> FrameOrSeries: """ Take the nth row from each group if n is an int, or a subset of rows if n is a list of ints. diff --git a/pandas/core/groupby/grouper.py b/pandas/core/groupby/grouper.py index 370abe75e1327..c950ccd6ed1e4 100644 --- a/pandas/core/groupby/grouper.py +++ b/pandas/core/groupby/grouper.py @@ -21,7 +21,7 @@ ) from pandas.core.dtypes.generic import ABCSeries -from pandas._typing import FrameOrSeries +from pandas._typing import FrameOrSeries, Union import pandas.core.algorithms as algorithms from pandas.core.arrays import Categorical, ExtensionArray import pandas.core.common as com @@ -249,7 +249,7 @@ def __init__( self, index: Index, grouper=None, - obj: Optional[FrameOrSeries] = None, + obj: Optional[Union[DataFrame, Series]] = None, name=None, level=None, sort: bool = True, @@ -570,8 +570,7 @@ def get_grouper( all_in_columns_index = all( g in obj.columns or g in obj.index.names for g in keys ) - else: - assert isinstance(obj, Series) + elif isinstance(obj, Series): all_in_columns_index = all(g in obj.index.names for g in keys) if not all_in_columns_index: