Skip to content

Commit 2f26003

Browse files
committed
Add types to more top-level funcs
1 parent 6322b8f commit 2f26003

File tree

8 files changed

+35
-17
lines changed

8 files changed

+35
-17
lines changed

pandas/core/algorithms.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
intended for public consumption
44
"""
55
from textwrap import dedent
6-
from typing import Dict, Optional, Tuple, Union
6+
from typing import Dict, Optional, TYPE_CHECKING, Tuple, Union
77
from warnings import catch_warnings, simplefilter, warn
88

99
import numpy as np
@@ -50,6 +50,9 @@
5050
from pandas.core.construction import array, extract_array
5151
from pandas.core.indexers import validate_indices
5252

53+
if TYPE_CHECKING:
54+
from pandas import Series
55+
5356
_shared_docs: Dict[str, str] = {}
5457

5558

@@ -651,7 +654,7 @@ def value_counts(
651654
normalize: bool = False,
652655
bins=None,
653656
dropna: bool = True,
654-
) -> ABCSeries:
657+
) -> "Series":
655658
"""
656659
Compute a histogram of the counts of non-null values.
657660
@@ -793,7 +796,7 @@ def duplicated(values, keep="first") -> np.ndarray:
793796
return f(values, keep=keep)
794797

795798

796-
def mode(values, dropna: bool = True) -> ABCSeries:
799+
def mode(values, dropna: bool = True) -> "Series":
797800
"""
798801
Returns the mode(s) of an array.
799802

pandas/core/frame.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5706,7 +5706,7 @@ def update(
57065706

57075707
@Substitution("")
57085708
@Appender(_shared_docs["pivot"])
5709-
def pivot(self, index=None, columns=None, values=None):
5709+
def pivot(self, index=None, columns=None, values=None) -> "DataFrame":
57105710
from pandas.core.reshape.pivot import pivot
57115711

57125712
return pivot(self, index=index, columns=columns, values=values)
@@ -5853,7 +5853,7 @@ def pivot_table(
58535853
dropna=True,
58545854
margins_name="All",
58555855
observed=False,
5856-
):
5856+
) -> "DataFrame":
58575857
from pandas.core.reshape.pivot import pivot_table
58585858

58595859
return pivot_table(

pandas/core/reshape/concat.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
concat routines
33
"""
44

5-
from typing import List
5+
from typing import List, Union
66

77
import numpy as np
88

@@ -37,7 +37,7 @@ def concat(
3737
verify_integrity: bool = False,
3838
sort: bool = False,
3939
copy: bool = True,
40-
):
40+
) -> Union["DataFrame", "Series"]:
4141
"""
4242
Concatenate pandas objects along a particular axis with optional set logic
4343
along the other axes.

pandas/core/reshape/melt.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ def lreshape(data: DataFrame, groups, dropna: bool = True, label=None) -> DataFr
192192
return data._constructor(mdata, columns=id_cols + pivot_cols)
193193

194194

195-
def wide_to_long(df: DataFrame, stubnames, i, j, sep: str = "", suffix: str = r"\d+"):
195+
def wide_to_long(
196+
df: DataFrame, stubnames, i, j, sep: str = "", suffix: str = r"\d+"
197+
) -> DataFrame:
196198
r"""
197199
Wide panel to long format. Less flexible but more user-friendly than melt.
198200

pandas/core/reshape/pivot.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def pivot_table(
5858
pieces.append(table)
5959
keys.append(getattr(func, "__name__", func))
6060

61-
return concat(pieces, keys=keys, axis=1)
61+
result = concat(pieces, keys=keys, axis=1)
62+
assert isinstance(result, DataFrame)
63+
return result
6264

6365
keys = index + columns
6466

@@ -461,7 +463,7 @@ def crosstab(
461463
margins_name: str = "All",
462464
dropna: bool = True,
463465
normalize=False,
464-
):
466+
) -> "DataFrame":
465467
"""
466468
Compute a simple cross tabulation of two (or more) factors. By default
467469
computes a frequency table of the factors unless an array of values and an

pandas/core/reshape/reshape.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from functools import partial
22
import itertools
3+
from typing import List
34

45
import numpy as np
56

@@ -755,7 +756,7 @@ def get_dummies(
755756
sparse=False,
756757
drop_first=False,
757758
dtype=None,
758-
):
759+
) -> "DataFrame":
759760
"""
760761
Convert categorical variable into dummy/indicator variables.
761762
@@ -899,7 +900,7 @@ def check_len(item, name):
899900

900901
if data_to_encode.shape == data.shape:
901902
# Encoding the entire df, do not prepend any dropped columns
902-
with_dummies = []
903+
with_dummies: List[DataFrame] = []
903904
elif columns is not None:
904905
# Encoding only cols specified in columns. Get all cols not in
905906
# columns to prepend to result.
@@ -921,7 +922,9 @@ def check_len(item, name):
921922
dtype=dtype,
922923
)
923924
with_dummies.append(dummy)
924-
result = concat(with_dummies, axis=1)
925+
concatted = concat(with_dummies, axis=1)
926+
assert isinstance(concatted, DataFrame)
927+
result = concatted
925928
else:
926929
result = _get_dummies_1d(
927930
data,

pandas/io/formats/format.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,9 @@ def _chk_truncate(self) -> None:
281281
series = series.iloc[:max_rows]
282282
else:
283283
row_num = max_rows // 2
284-
series = concat((series.iloc[:row_num], series.iloc[-row_num:]))
284+
concatted = concat((series.iloc[:row_num], series.iloc[-row_num:]))
285+
assert isinstance(concatted, Series)
286+
series = concatted
285287
self.tr_row_num = row_num
286288
else:
287289
self.tr_row_num = None
@@ -676,9 +678,11 @@ def _chk_truncate(self) -> None:
676678
col_num = max_cols
677679
else:
678680
col_num = max_cols_adj // 2
679-
frame = concat(
681+
concatted = concat(
680682
(frame.iloc[:, :col_num], frame.iloc[:, -col_num:]), axis=1
681683
)
684+
assert isinstance(concatted, DataFrame)
685+
frame = concatted
682686
# truncate formatter
683687
if isinstance(self.formatters, (list, tuple)):
684688
truncate_fmt = self.formatters
@@ -695,7 +699,9 @@ def _chk_truncate(self) -> None:
695699
frame = frame.iloc[:max_rows, :]
696700
else:
697701
row_num = max_rows_adj // 2
698-
frame = concat((frame.iloc[:row_num, :], frame.iloc[-row_num:, :]))
702+
concatted = concat((frame.iloc[:row_num, :], frame.iloc[-row_num:, :]))
703+
assert isinstance(concatted, DataFrame)
704+
frame = concatted
699705
self.tr_row_num = row_num
700706
else:
701707
self.tr_row_num = None

pandas/io/pytables.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -4432,7 +4432,9 @@ def read(
44324432
if len(frames) == 1:
44334433
df = frames[0]
44344434
else:
4435-
df = concat(frames, axis=1)
4435+
concatted = concat(frames, axis=1)
4436+
assert isinstance(concatted, DataFrame)
4437+
df = concatted
44364438

44374439
selection = Selection(self, where=where, start=start, stop=stop)
44384440
# apply the selection filters & axis orderings

0 commit comments

Comments
 (0)