Skip to content

TYPING: change to FrameOrSeries Alias in pandas._typing #29480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from pandas.core.arrays.base import ExtensionArray # noqa: F401
from pandas.core.dtypes.dtypes import ExtensionDtype # noqa: F401
from pandas.core.indexes.base import Index # noqa: F401
from pandas.core.frame import DataFrame # noqa: F401
from pandas.core.series import Series # noqa: F401
from pandas.core.generic import NDFrame # noqa: F401


AnyArrayLike = TypeVar("AnyArrayLike", "ExtensionArray", "Index", "Series", np.ndarray)
Expand All @@ -31,7 +31,7 @@
Dtype = Union[str, np.dtype, "ExtensionDtype"]
FilePathOrBuffer = Union[str, Path, IO[AnyStr]]

FrameOrSeries = TypeVar("FrameOrSeries", bound="NDFrame")
FrameOrSeries = TypeVar("FrameOrSeries", "DataFrame", "Series")
Scalar = Union[str, int, float, bool]
Axis = Union[str, int]
Ordered = Optional[bool]
Expand Down
23 changes: 18 additions & 5 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@ class providing the base-class of operations.
import inspect
import re
import types
from typing import FrozenSet, Hashable, Iterable, List, Optional, Tuple, Type, Union
from typing import (
FrozenSet,
Generic,
Hashable,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
)

import numpy as np

Expand All @@ -41,6 +51,7 @@ class providing the base-class of operations.
)
from pandas.core.dtypes.missing import isna, notna

from pandas._typing import FrameOrSeries
from pandas.core import nanops
import pandas.core.algorithms as algorithms
from pandas.core.arrays import Categorical, try_cast_to_ea
Expand Down Expand Up @@ -336,13 +347,13 @@ def _group_selection_context(groupby):
groupby._reset_group_selection()


class _GroupBy(PandasObject, SelectionMixin):
class _GroupBy(PandasObject, SelectionMixin, Generic[FrameOrSeries]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm not sure I would consider _GroupBy a generic class - is this required as part of this change or just experimenting?

_group_selection = None
_apply_whitelist = frozenset() # type: FrozenSet[str]

def __init__(
self,
obj: NDFrame,
obj: FrameOrSeries,
keys=None,
axis: int = 0,
level=None,
Expand Down Expand Up @@ -391,7 +402,7 @@ def __init__(
mutated=self.mutated,
)

self.obj = obj
self.obj = obj # type: FrameOrSeries
self.axis = obj._get_axis_number(axis)
self.grouper = grouper
self.exclusions = set(exclusions) if exclusions else set()
Expand Down Expand Up @@ -1662,7 +1673,9 @@ def backfill(self, limit=None):

@Substitution(name="groupby")
@Substitution(see_also=_common_see_also)
def nth(self, n: Union[int, List[int]], dropna: Optional[str] = None) -> DataFrame:
def nth(
self, n: Union[int, List[int]], dropna: Optional[str] = None
) -> FrameOrSeries:
"""
Take the nth row from each group if n is an int, or a subset of rows
if n is a list of ints.
Expand Down
7 changes: 3 additions & 4 deletions pandas/core/groupby/grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from pandas.core.dtypes.generic import ABCSeries

from pandas._typing import FrameOrSeries
from pandas._typing import FrameOrSeries, Union
import pandas.core.algorithms as algorithms
from pandas.core.arrays import Categorical, ExtensionArray
import pandas.core.common as com
Expand Down Expand Up @@ -249,7 +249,7 @@ def __init__(
self,
index: Index,
grouper=None,
obj: Optional[FrameOrSeries] = None,
obj: Optional[Union[DataFrame, Series]] = None,
name=None,
level=None,
sort: bool = True,
Expand Down Expand Up @@ -570,8 +570,7 @@ def get_grouper(
all_in_columns_index = all(
g in obj.columns or g in obj.index.names for g in keys
)
else:
assert isinstance(obj, Series)
elif isinstance(obj, Series):
all_in_columns_index = all(g in obj.index.names for g in keys)

if not all_in_columns_index:
Expand Down