Skip to content

Commit 88f9556

Browse files
committed
overload concat function
1 parent 61372ea commit 88f9556

File tree

5 files changed

+39
-19
lines changed

5 files changed

+39
-19
lines changed

pandas/core/reshape/concat.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
concat routines
33
"""
44

5-
from typing import List, Union
5+
from typing import List, Mapping, Sequence, Union, overload
66

77
import numpy as np
88

@@ -25,9 +25,12 @@
2525
# ---------------------------------------------------------------------
2626
# Concatenate DataFrame objects
2727

28+
FrameOrSeriesUnion = Union["DataFrame", "Series"]
2829

30+
31+
@overload
2932
def concat(
30-
objs,
33+
objs: Union[Sequence["DataFrame"], Mapping[str, "DataFrame"]],
3134
axis=0,
3235
join: str = "outer",
3336
ignore_index: bool = False,
@@ -37,6 +40,33 @@ def concat(
3740
verify_integrity: bool = False,
3841
sort: bool = False,
3942
copy: bool = True,
43+
) -> "DataFrame":
44+
...
45+
@overload # noqa: E302
46+
def concat(
47+
objs: Union[Sequence[FrameOrSeriesUnion], Mapping[str, FrameOrSeriesUnion]],
48+
axis=0,
49+
join: str = "outer",
50+
ignore_index: bool = False,
51+
keys=None,
52+
levels=None,
53+
names=None,
54+
verify_integrity: bool = False,
55+
sort: bool = False,
56+
copy: bool = True,
57+
) -> Union["DataFrame", "Series"]:
58+
...
59+
def concat( # noqa: E302
60+
objs: Union[Sequence[FrameOrSeriesUnion], Mapping[str, FrameOrSeriesUnion]],
61+
axis=0,
62+
join="outer",
63+
ignore_index: bool = False,
64+
keys=None,
65+
levels=None,
66+
names=None,
67+
verify_integrity: bool = False,
68+
sort: bool = False,
69+
copy: bool = True,
4070
) -> Union["DataFrame", "Series"]:
4171
"""
4272
Concatenate pandas objects along a particular axis with optional set logic

pandas/core/reshape/pivot.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Callable, Dict, Tuple, Union
1+
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union
22

33
import numpy as np
44

@@ -40,7 +40,7 @@ def pivot_table(
4040
columns = _convert_by(columns)
4141

4242
if isinstance(aggfunc, list):
43-
pieces = []
43+
pieces: List[DataFrame] = []
4444
keys = []
4545
for func in aggfunc:
4646
table = pivot_table(
@@ -58,9 +58,7 @@ def pivot_table(
5858
pieces.append(table)
5959
keys.append(getattr(func, "__name__", func))
6060

61-
result = concat(pieces, keys=keys, axis=1)
62-
assert isinstance(result, DataFrame)
63-
return result
61+
return concat(pieces, keys=keys, axis=1)
6462

6563
keys = index + columns
6664

pandas/core/reshape/reshape.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -922,9 +922,7 @@ def check_len(item, name):
922922
dtype=dtype,
923923
)
924924
with_dummies.append(dummy)
925-
concatted = concat(with_dummies, axis=1)
926-
assert isinstance(concatted, DataFrame)
927-
result = concatted
925+
result = concat(with_dummies, axis=1)
928926
else:
929927
result = _get_dummies_1d(
930928
data,

pandas/io/formats/format.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -678,11 +678,9 @@ def _chk_truncate(self) -> None:
678678
col_num = max_cols
679679
else:
680680
col_num = max_cols_adj // 2
681-
concatted = concat(
681+
frame = concat(
682682
(frame.iloc[:, :col_num], frame.iloc[:, -col_num:]), axis=1
683683
)
684-
assert isinstance(concatted, DataFrame)
685-
frame = concatted
686684
# truncate formatter
687685
if isinstance(self.formatters, (list, tuple)):
688686
truncate_fmt = self.formatters
@@ -699,9 +697,7 @@ def _chk_truncate(self) -> None:
699697
frame = frame.iloc[:max_rows, :]
700698
else:
701699
row_num = max_rows_adj // 2
702-
concatted = concat((frame.iloc[:row_num, :], frame.iloc[-row_num:, :]))
703-
assert isinstance(concatted, DataFrame)
704-
frame = concatted
700+
frame = concat((frame.iloc[:row_num, :], frame.iloc[-row_num:, :]))
705701
self.tr_row_num = row_num
706702
else:
707703
self.tr_row_num = None

pandas/io/pytables.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -4432,9 +4432,7 @@ def read(
44324432
if len(frames) == 1:
44334433
df = frames[0]
44344434
else:
4435-
concatted = concat(frames, axis=1)
4436-
assert isinstance(concatted, DataFrame)
4437-
df = concatted
4435+
df = concat(frames, axis=1)
44384436

44394437
selection = Selection(self, where=where, start=start, stop=stop)
44404438
# apply the selection filters & axis orderings

0 commit comments

Comments
 (0)