diff --git a/pandas/_typing.py b/pandas/_typing.py index 7b89486751f12..14cf5157cea1d 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -21,9 +21,10 @@ from pandas.core.arrays.base import ExtensionArray # noqa: F401 from pandas.core.dtypes.dtypes import ExtensionDtype # noqa: F401 from pandas.core.indexes.base import Index # noqa: F401 - from pandas.core.series import Series # noqa: F401 from pandas.core.generic import NDFrame # noqa: F401 from pandas import Interval # noqa: F401 + from pandas.core.series import Series # noqa: F401 + from pandas.core.frame import DataFrame # noqa: F401 # array-like @@ -41,7 +42,19 @@ Dtype = Union[str, np.dtype, "ExtensionDtype"] FilePathOrBuffer = Union[str, Path, IO[AnyStr]] + +# FrameOrSeriesUnion means either a DataFrame or a Series. E.g. +# `def func(a: FrameOrSeriesUnion) -> FrameOrSeriesUnion: ...` means that if a Series +# is passed in, either a Series or DataFrame is returned, and if a DataFrame is passed +# in, either a DataFrame or a Series is returned. +FrameOrSeriesUnion = Union["DataFrame", "Series"] + +# FrameOrSeries is stricter and ensures that the same subclass of NDFrame always is +# used. E.g. `def func(a: FrameOrSeries) -> FrameOrSeries: ...` means that if a +# Series is passed into a function, a Series is always returned and if a DataFrame is +# passed in, a DataFrame is always returned. FrameOrSeries = TypeVar("FrameOrSeries", bound="NDFrame") + Axis = Union[str, int] Ordered = Optional[bool] JSONSerializable = Union[PythonScalar, List, Dict] diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 42cfd9d54ac19..39e8e9008a844 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -3,7 +3,7 @@ intended for public consumption """ from textwrap import dedent -from typing import Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from warnings import catch_warnings, simplefilter, warn import numpy as np @@ -50,6 +50,9 @@ from pandas.core.construction import array, extract_array from pandas.core.indexers import validate_indices +if TYPE_CHECKING: + from pandas import Series + _shared_docs: Dict[str, str] = {} @@ -651,7 +654,7 @@ def value_counts( normalize: bool = False, bins=None, dropna: bool = True, -) -> ABCSeries: +) -> "Series": """ Compute a histogram of the counts of non-null values. @@ -793,7 +796,7 @@ def duplicated(values, keep="first") -> np.ndarray: return f(values, keep=keep) -def mode(values, dropna: bool = True) -> ABCSeries: +def mode(values, dropna: bool = True) -> "Series": """ Returns the mode(s) of an array. diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 97b218878f4cc..ba0c0e7d66b1d 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5878,7 +5878,7 @@ def groupby( @Substitution("") @Appender(_shared_docs["pivot"]) - def pivot(self, index=None, columns=None, values=None): + def pivot(self, index=None, columns=None, values=None) -> "DataFrame": from pandas.core.reshape.pivot import pivot return pivot(self, index=index, columns=columns, values=values) @@ -6025,7 +6025,7 @@ def pivot_table( dropna=True, margins_name="All", observed=False, - ): + ) -> "DataFrame": from pandas.core.reshape.pivot import pivot_table return pivot_table( diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py index 8886732fc8d79..2007f6aa32a57 100644 --- a/pandas/core/reshape/concat.py +++ b/pandas/core/reshape/concat.py @@ -2,10 +2,12 @@ concat routines """ -from typing import Hashable, List, Optional +from typing import Hashable, List, Mapping, Optional, Sequence, Union, overload import numpy as np +from pandas._typing import FrameOrSeriesUnion + from pandas import DataFrame, Index, MultiIndex, Series from pandas.core.arrays.categorical import ( factorize_from_iterable, @@ -26,8 +28,27 @@ # Concatenate DataFrame objects +@overload +def concat( + objs: Union[Sequence["DataFrame"], Mapping[Optional[Hashable], "DataFrame"]], + axis=0, + join: str = "outer", + ignore_index: bool = False, + keys=None, + levels=None, + names=None, + verify_integrity: bool = False, + sort: bool = False, + copy: bool = True, +) -> "DataFrame": + ... + + +@overload def concat( - objs, + objs: Union[ + Sequence[FrameOrSeriesUnion], Mapping[Optional[Hashable], FrameOrSeriesUnion] + ], axis=0, join: str = "outer", ignore_index: bool = False, @@ -37,7 +58,24 @@ def concat( verify_integrity: bool = False, sort: bool = False, copy: bool = True, -): +) -> FrameOrSeriesUnion: + ... + + +def concat( + objs: Union[ + Sequence[FrameOrSeriesUnion], Mapping[Optional[Hashable], FrameOrSeriesUnion] + ], + axis=0, + join="outer", + ignore_index: bool = False, + keys=None, + levels=None, + names=None, + verify_integrity: bool = False, + sort: bool = False, + copy: bool = True, +) -> FrameOrSeriesUnion: """ Concatenate pandas objects along a particular axis with optional set logic along the other axes. diff --git a/pandas/core/reshape/melt.py b/pandas/core/reshape/melt.py index 38bda94489d01..722dd8751dfad 100644 --- a/pandas/core/reshape/melt.py +++ b/pandas/core/reshape/melt.py @@ -192,7 +192,9 @@ def lreshape(data: DataFrame, groups, dropna: bool = True, label=None) -> DataFr return data._constructor(mdata, columns=id_cols + pivot_cols) -def wide_to_long(df: DataFrame, stubnames, i, j, sep: str = "", suffix: str = r"\d+"): +def wide_to_long( + df: DataFrame, stubnames, i, j, sep: str = "", suffix: str = r"\d+" +) -> DataFrame: r""" Wide panel to long format. Less flexible but more user-friendly than melt. diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py index c544c132d6921..b443ba142369c 100644 --- a/pandas/core/reshape/pivot.py +++ b/pandas/core/reshape/pivot.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, Dict, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union import numpy as np @@ -40,7 +40,7 @@ def pivot_table( columns = _convert_by(columns) if isinstance(aggfunc, list): - pieces = [] + pieces: List[DataFrame] = [] keys = [] for func in aggfunc: table = pivot_table( @@ -459,7 +459,7 @@ def crosstab( margins_name: str = "All", dropna: bool = True, normalize=False, -): +) -> "DataFrame": """ Compute a simple cross tabulation of two (or more) factors. By default computes a frequency table of the factors unless an array of values and an diff --git a/pandas/core/reshape/reshape.py b/pandas/core/reshape/reshape.py index 004bd0199eb58..da92e1154556a 100644 --- a/pandas/core/reshape/reshape.py +++ b/pandas/core/reshape/reshape.py @@ -1,5 +1,6 @@ from functools import partial import itertools +from typing import List import numpy as np @@ -755,7 +756,7 @@ def get_dummies( sparse=False, drop_first=False, dtype=None, -): +) -> "DataFrame": """ Convert categorical variable into dummy/indicator variables. @@ -899,7 +900,7 @@ def check_len(item, name): if data_to_encode.shape == data.shape: # Encoding the entire df, do not prepend any dropped columns - with_dummies = [] + with_dummies: List[DataFrame] = [] elif columns is not None: # Encoding only cols specified in columns. Get all cols not in # columns to prepend to result. diff --git a/pandas/io/formats/format.py b/pandas/io/formats/format.py index ea22999470102..3020ac421fc2f 100644 --- a/pandas/io/formats/format.py +++ b/pandas/io/formats/format.py @@ -281,7 +281,9 @@ def _chk_truncate(self) -> None: series = series.iloc[:max_rows] else: row_num = max_rows // 2 - series = concat((series.iloc[:row_num], series.iloc[-row_num:])) + series = series._ensure_type( + concat((series.iloc[:row_num], series.iloc[-row_num:])) + ) self.tr_row_num = row_num else: self.tr_row_num = None