Skip to content

Commit 22c44d0

Browse files
pckSFMarcoGorelli
authored andcommitted
TYP: Overload concat (pandas-dev#41184)
* TYP: Overload concat * Fix NDFrame columns overload * Fix FrameOrSeries * Fix redundant casts * Import Axis outside of TYPE_CHECKING * Remove redundant overloads * Mark mypy/issues/8354 return type issues * Fix missing imports * noop Co-authored-by: Marco Gorelli <[email protected]>
1 parent 786aa1c commit 22c44d0

File tree

6 files changed

+77
-40
lines changed

6 files changed

+77
-40
lines changed

pandas/core/arrays/categorical.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2310,10 +2310,11 @@ def describe(self):
23102310
counts = self.value_counts(dropna=False)
23112311
freqs = counts / counts.sum()
23122312

2313+
from pandas import Index
23132314
from pandas.core.reshape.concat import concat
23142315

23152316
result = concat([counts, freqs], axis=1)
2316-
result.columns = ["counts", "freqs"]
2317+
result.columns = Index(["counts", "freqs"])
23172318
result.index.name = "categories"
23182319

23192320
return result

pandas/core/generic.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -5755,7 +5755,8 @@ def astype(
57555755
# GH 19920: retain column metadata after concat
57565756
result = concat(results, axis=1, copy=False)
57575757
result.columns = self.columns
5758-
return result
5758+
# https://github.com/python/mypy/issues/8354
5759+
return cast(FrameOrSeries, result)
57595760

57605761
@final
57615762
def copy(self: FrameOrSeries, deep: bool_t = True) -> FrameOrSeries:
@@ -6118,7 +6119,8 @@ def convert_dtypes(
61186119
for col_name, col in self.items()
61196120
]
61206121
if len(results) > 0:
6121-
return concat(results, axis=1, copy=False)
6122+
# https://github.com/python/mypy/issues/8354
6123+
return cast(FrameOrSeries, concat(results, axis=1, copy=False))
61226124
else:
61236125
return self.copy()
61246126

pandas/core/groupby/generic.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,7 @@ def _aggregate_multiple_funcs(self, arg) -> DataFrame:
308308
res_df = concat(
309309
results.values(), axis=1, keys=[key.label for key in results.keys()]
310310
)
311-
# error: Incompatible return value type (got "Union[DataFrame, Series]",
312-
# expected "DataFrame")
313-
return res_df # type: ignore[return-value]
311+
return res_df
314312

315313
indexed_output = {key.position: val for key, val in results.items()}
316314
output = self.obj._constructor_expanddim(indexed_output, index=None)
@@ -547,9 +545,7 @@ def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
547545
result = self.obj._constructor(dtype=np.float64)
548546

549547
result.name = self.obj.name
550-
# error: Incompatible return value type (got "Union[DataFrame, Series]",
551-
# expected "Series")
552-
return result # type: ignore[return-value]
548+
return result
553549

554550
def _can_use_transform_fast(self, result) -> bool:
555551
return True

pandas/core/reshape/concat.py

+64-14
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
TYPE_CHECKING,
99
Hashable,
1010
Iterable,
11+
Literal,
1112
Mapping,
1213
cast,
1314
overload,
1415
)
1516

1617
import numpy as np
1718

19+
from pandas._typing import Axis
1820
from pandas.util._decorators import (
1921
cache_readonly,
2022
deprecate_nonkeyword_arguments,
@@ -57,25 +59,73 @@
5759
@overload
5860
def concat(
5961
objs: Iterable[DataFrame] | Mapping[Hashable, DataFrame],
60-
axis=0,
61-
join: str = "outer",
62-
ignore_index: bool = False,
63-
keys=None,
64-
levels=None,
65-
names=None,
66-
verify_integrity: bool = False,
67-
sort: bool = False,
68-
copy: bool = True,
62+
axis: Literal[0, "index"] = ...,
63+
join: str = ...,
64+
ignore_index: bool = ...,
65+
keys=...,
66+
levels=...,
67+
names=...,
68+
verify_integrity: bool = ...,
69+
sort: bool = ...,
70+
copy: bool = ...,
6971
) -> DataFrame:
7072
...
7173

7274

75+
@overload
76+
def concat(
77+
objs: Iterable[Series] | Mapping[Hashable, Series],
78+
axis: Literal[0, "index"] = ...,
79+
join: str = ...,
80+
ignore_index: bool = ...,
81+
keys=...,
82+
levels=...,
83+
names=...,
84+
verify_integrity: bool = ...,
85+
sort: bool = ...,
86+
copy: bool = ...,
87+
) -> Series:
88+
...
89+
90+
7391
@overload
7492
def concat(
7593
objs: Iterable[NDFrame] | Mapping[Hashable, NDFrame],
76-
axis=0,
77-
join: str = "outer",
78-
ignore_index: bool = False,
94+
axis: Literal[0, "index"] = ...,
95+
join: str = ...,
96+
ignore_index: bool = ...,
97+
keys=...,
98+
levels=...,
99+
names=...,
100+
verify_integrity: bool = ...,
101+
sort: bool = ...,
102+
copy: bool = ...,
103+
) -> DataFrame | Series:
104+
...
105+
106+
107+
@overload
108+
def concat(
109+
objs: Iterable[NDFrame] | Mapping[Hashable, NDFrame],
110+
axis: Literal[1, "columns"],
111+
join: str = ...,
112+
ignore_index: bool = ...,
113+
keys=...,
114+
levels=...,
115+
names=...,
116+
verify_integrity: bool = ...,
117+
sort: bool = ...,
118+
copy: bool = ...,
119+
) -> DataFrame:
120+
...
121+
122+
123+
@overload
124+
def concat(
125+
objs: Iterable[NDFrame] | Mapping[Hashable, NDFrame],
126+
axis: Axis = ...,
127+
join: str = ...,
128+
ignore_index: bool = ...,
79129
keys=None,
80130
levels=None,
81131
names=None,
@@ -89,8 +139,8 @@ def concat(
89139
@deprecate_nonkeyword_arguments(version=None, allowed_args=["objs"])
90140
def concat(
91141
objs: Iterable[NDFrame] | Mapping[Hashable, NDFrame],
92-
axis=0,
93-
join="outer",
142+
axis: Axis = 0,
143+
join: str = "outer",
94144
ignore_index: bool = False,
95145
keys=None,
96146
levels=None,

pandas/core/reshape/melt.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
from __future__ import annotations
22

33
import re
4-
from typing import (
5-
TYPE_CHECKING,
6-
cast,
7-
)
4+
from typing import TYPE_CHECKING
85
import warnings
96

107
import numpy as np
@@ -34,10 +31,7 @@
3431
from pandas.core.tools.numeric import to_numeric
3532

3633
if TYPE_CHECKING:
37-
from pandas import (
38-
DataFrame,
39-
Series,
40-
)
34+
from pandas import DataFrame
4135

4236

4337
@Appender(_shared_docs["melt"] % {"caller": "pd.melt(df, ", "other": "DataFrame.melt"})
@@ -136,7 +130,7 @@ def melt(
136130
for col in id_vars:
137131
id_data = frame.pop(col)
138132
if is_extension_array_dtype(id_data):
139-
id_data = cast("Series", concat([id_data] * K, ignore_index=True))
133+
id_data = concat([id_data] * K, ignore_index=True)
140134
else:
141135
id_data = np.tile(id_data._values, K)
142136
mdata[col] = id_data

pandas/core/reshape/reshape.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
from __future__ import annotations
22

33
import itertools
4-
from typing import (
5-
TYPE_CHECKING,
6-
cast,
7-
)
4+
from typing import TYPE_CHECKING
85

96
import numpy as np
107

@@ -1059,10 +1056,7 @@ def get_empty_frame(data) -> DataFrame:
10591056
)
10601057
sparse_series.append(Series(data=sarr, index=index, name=col))
10611058

1062-
out = concat(sparse_series, axis=1, copy=False)
1063-
# TODO: overload concat with Literal for axis
1064-
out = cast(DataFrame, out)
1065-
return out
1059+
return concat(sparse_series, axis=1, copy=False)
10661060

10671061
else:
10681062
# take on axis=1 + transpose to ensure ndarray layout is column-major

0 commit comments

Comments
 (0)