
Commit 66b4dbc

topper-123 authored and jreback committed
TYP: Add types to top-level funcs, step 2 (#30582)
1 parent 3253cb0 commit 66b4dbc

8 files changed: +75 -16


pandas/_typing.py

+14 -1

@@ -21,9 +21,10 @@
     from pandas.core.arrays.base import ExtensionArray  # noqa: F401
     from pandas.core.dtypes.dtypes import ExtensionDtype  # noqa: F401
     from pandas.core.indexes.base import Index  # noqa: F401
-    from pandas.core.series import Series  # noqa: F401
     from pandas.core.generic import NDFrame  # noqa: F401
     from pandas import Interval  # noqa: F401
+    from pandas.core.series import Series  # noqa: F401
+    from pandas.core.frame import DataFrame  # noqa: F401

 # array-like

@@ -41,7 +42,19 @@

 Dtype = Union[str, np.dtype, "ExtensionDtype"]
 FilePathOrBuffer = Union[str, Path, IO[AnyStr]]
+
+# FrameOrSeriesUnion means either a DataFrame or a Series. E.g.
+# `def func(a: FrameOrSeriesUnion) -> FrameOrSeriesUnion: ...` means that if a Series
+# is passed in, either a Series or DataFrame is returned, and if a DataFrame is passed
+# in, either a DataFrame or a Series is returned.
+FrameOrSeriesUnion = Union["DataFrame", "Series"]
+
+# FrameOrSeries is stricter and ensures that the same subclass of NDFrame always is
+# used. E.g. `def func(a: FrameOrSeries) -> FrameOrSeries: ...` means that if a
+# Series is passed into a function, a Series is always returned and if a DataFrame is
+# passed in, a DataFrame is always returned.
 FrameOrSeries = TypeVar("FrameOrSeries", bound="NDFrame")
+
 Axis = Union[str, int]
 Ordered = Optional[bool]
 JSONSerializable = Union[PythonScalar, List, Dict]
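
The two aliases behave differently under a type checker. A minimal sketch of the distinction, assuming a pandas version where both names exist in pandas._typing (the functions below are illustrative, not part of the commit):

import pandas as pd
from pandas._typing import FrameOrSeries, FrameOrSeriesUnion

def rename_index(obj: FrameOrSeries, prefix: str) -> FrameOrSeries:
    # TypeVar: whatever NDFrame subclass comes in is what goes out, so a
    # checker knows rename_index(some_df, "x") is still a DataFrame.
    out = obj.copy()
    out.index = [f"{prefix}{i}" for i in obj.index]
    return out

def maybe_to_frame(obj: FrameOrSeriesUnion, as_frame: bool) -> FrameOrSeriesUnion:
    # Union: input and output types are not linked, so a Series argument may
    # legitimately come back as a DataFrame.
    if as_frame and isinstance(obj, pd.Series):
        return obj.to_frame()
    return obj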

pandas/core/algorithms.py

+6 -3

@@ -3,7 +3,7 @@
 intended for public consumption
 """
 from textwrap import dedent
-from typing import Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
 from warnings import catch_warnings, simplefilter, warn

 import numpy as np
@@ -50,6 +50,9 @@
 from pandas.core.construction import array, extract_array
 from pandas.core.indexers import validate_indices

+if TYPE_CHECKING:
+    from pandas import Series
+
 _shared_docs: Dict[str, str] = {}


@@ -651,7 +654,7 @@ def value_counts(
     normalize: bool = False,
     bins=None,
     dropna: bool = True,
-) -> ABCSeries:
+) -> "Series":
     """
     Compute a histogram of the counts of non-null values.

@@ -793,7 +796,7 @@ def duplicated(values, keep="first") -> np.ndarray:
     return f(values, keep=keep)


-def mode(values, dropna: bool = True) -> ABCSeries:
+def mode(values, dropna: bool = True) -> "Series":
     """
     Returns the mode(s) of an array.
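
The switch from ABCSeries to the quoted "Series" annotation relies on the TYPE_CHECKING guard added above: the import is only evaluated during static analysis, so there is no runtime import cycle, while mypy still resolves the forward reference. A minimal sketch of the same pattern in an arbitrary module (illustrative, not code from the commit):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for annotations; never executed at runtime.
    from pandas import Series

def nonnull_counts(values) -> "Series":
    # The quoted return type is a forward reference that the checker
    # resolves through the TYPE_CHECKING import above.
    import pandas as pd
    return pd.Series(values).value_counts(dropna=True)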

pandas/core/frame.py

+2 -2

@@ -5878,7 +5878,7 @@ def groupby(

     @Substitution("")
     @Appender(_shared_docs["pivot"])
-    def pivot(self, index=None, columns=None, values=None):
+    def pivot(self, index=None, columns=None, values=None) -> "DataFrame":
         from pandas.core.reshape.pivot import pivot

         return pivot(self, index=index, columns=columns, values=values)
@@ -6025,7 +6025,7 @@ def pivot_table(
         dropna=True,
         margins_name="All",
         observed=False,
-    ):
+    ) -> "DataFrame":
         from pandas.core.reshape.pivot import pivot_table

         return pivot_table(

pandas/core/reshape/concat.py

+41 -3

@@ -2,10 +2,12 @@
 concat routines
 """

-from typing import Hashable, List, Optional
+from typing import Hashable, List, Mapping, Optional, Sequence, Union, overload

 import numpy as np

+from pandas._typing import FrameOrSeriesUnion
+
 from pandas import DataFrame, Index, MultiIndex, Series
 from pandas.core.arrays.categorical import (
     factorize_from_iterable,
@@ -26,8 +28,27 @@
 # Concatenate DataFrame objects


+@overload
+def concat(
+    objs: Union[Sequence["DataFrame"], Mapping[Optional[Hashable], "DataFrame"]],
+    axis=0,
+    join: str = "outer",
+    ignore_index: bool = False,
+    keys=None,
+    levels=None,
+    names=None,
+    verify_integrity: bool = False,
+    sort: bool = False,
+    copy: bool = True,
+) -> "DataFrame":
+    ...
+
+
+@overload
 def concat(
-    objs,
+    objs: Union[
+        Sequence[FrameOrSeriesUnion], Mapping[Optional[Hashable], FrameOrSeriesUnion]
+    ],
     axis=0,
     join: str = "outer",
     ignore_index: bool = False,
@@ -37,7 +58,24 @@ def concat(
     verify_integrity: bool = False,
     sort: bool = False,
     copy: bool = True,
-):
+) -> FrameOrSeriesUnion:
+    ...
+
+
+def concat(
+    objs: Union[
+        Sequence[FrameOrSeriesUnion], Mapping[Optional[Hashable], FrameOrSeriesUnion]
+    ],
+    axis=0,
+    join="outer",
+    ignore_index: bool = False,
+    keys=None,
+    levels=None,
+    names=None,
+    verify_integrity: bool = False,
+    sort: bool = False,
+    copy: bool = True,
+) -> FrameOrSeriesUnion:
     """
     Concatenate pandas objects along a particular axis with optional set logic
     along the other axes.
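
With the two overloads in place, a checker can narrow pd.concat's return type from the call site: an all-DataFrame input matches the first stub and is typed as DataFrame, while anything mixing in a Series falls through to the FrameOrSeriesUnion stub. A rough usage sketch, assuming these overloads are exposed through pd.concat:

import pandas as pd

df1 = pd.DataFrame({"a": [1, 2]})
df2 = pd.DataFrame({"a": [3, 4]})
ser = pd.Series([5, 6], name="a")

frames = pd.concat([df1, df2])   # first overload applies: typed as DataFrame
mixed = pd.concat([df1, ser])    # second overload applies: DataFrame or Series

frames.sort_values("a")          # accepted: frames is known to be a DataFrame
if isinstance(mixed, pd.DataFrame):
    mixed.sort_values("a")       # the Union result must be narrowed first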

pandas/core/reshape/melt.py

+3 -1

@@ -192,7 +192,9 @@ def lreshape(data: DataFrame, groups, dropna: bool = True, label=None) -> DataFr
     return data._constructor(mdata, columns=id_cols + pivot_cols)


-def wide_to_long(df: DataFrame, stubnames, i, j, sep: str = "", suffix: str = r"\d+"):
+def wide_to_long(
+    df: DataFrame, stubnames, i, j, sep: str = "", suffix: str = r"\d+"
+) -> DataFrame:
     r"""
     Wide panel to long format. Less flexible but more user-friendly than melt.

pandas/core/reshape/pivot.py

+3 -3

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Callable, Dict, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union

 import numpy as np

@@ -40,7 +40,7 @@ def pivot_table(
     columns = _convert_by(columns)

     if isinstance(aggfunc, list):
-        pieces = []
+        pieces: List[DataFrame] = []
         keys = []
         for func in aggfunc:
             table = pivot_table(
@@ -459,7 +459,7 @@ def crosstab(
     margins_name: str = "All",
     dropna: bool = True,
     normalize=False,
-):
+) -> "DataFrame":
     """
     Compute a simple cross tabulation of two (or more) factors. By default
     computes a frequency table of the factors unless an array of values and an

pandas/core/reshape/reshape.py

+3 -2

@@ -1,5 +1,6 @@
 from functools import partial
 import itertools
+from typing import List

 import numpy as np

@@ -755,7 +756,7 @@ def get_dummies(
     sparse=False,
     drop_first=False,
     dtype=None,
-):
+) -> "DataFrame":
     """
     Convert categorical variable into dummy/indicator variables.

@@ -899,7 +900,7 @@ def check_len(item, name):

     if data_to_encode.shape == data.shape:
         # Encoding the entire df, do not prepend any dropped columns
-        with_dummies = []
+        with_dummies: List[DataFrame] = []
     elif columns is not None:
         # Encoding only cols specified in columns. Get all cols not in
         # columns to prepend to result.
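
The pieces and with_dummies annotations exist because an empty list literal gives mypy nothing to infer an element type from (it typically reports "need type annotation" for a bare accumulator). A small illustrative sketch of the pattern, not taken from the commit:

from typing import List
import pandas as pd

# Without the annotation, a checker cannot tell what this list will hold.
pieces: List[pd.DataFrame] = []
for i in range(3):
    pieces.append(pd.DataFrame({"x": [i]}))

# Element type is known, so attribute access on the items is checked.
total_rows = sum(len(piece) for piece in pieces)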

pandas/io/formats/format.py

+3 -1

@@ -281,7 +281,9 @@ def _chk_truncate(self) -> None:
                 series = series.iloc[:max_rows]
             else:
                 row_num = max_rows // 2
-                series = concat((series.iloc[:row_num], series.iloc[-row_num:]))
+                series = series._ensure_type(
+                    concat((series.iloc[:row_num], series.iloc[-row_num:]))
+                )
                 self.tr_row_num = row_num
         else:
             self.tr_row_num = None
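
Because concat is now annotated as returning FrameOrSeriesUnion, reassigning its result to a variable the formatter treats as a Series needs a runtime check to keep the narrower type; _ensure_type is the internal NDFrame helper used here for that. A rough stand-alone equivalent of the idea (ensure_same_type below is a hypothetical helper, not the pandas method):

from typing import TypeVar
import pandas as pd

T = TypeVar("T")

def ensure_same_type(template: T, value) -> T:
    # Assert that value has the same type as template so a checker can keep
    # treating the result as that type (here: Series rather than a Union).
    assert isinstance(value, type(template)), type(value)
    return value

ser = pd.Series(range(10))
truncated = ensure_same_type(ser, pd.concat((ser.iloc[:3], ser.iloc[-3:])))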
