Skip to content

Commit 015c0c0

Browse files
TYP: use overload to refine return type of set_axis (#40197)
* try typing set_axis * fixup overloads * remove return type from base definition * searching in the sun for another overload * bool * allow defaults in bool case * type non-overloaded arguments Co-authored-by: Simon Hawkins <[email protected]>
1 parent 1cc40fa commit 015c0c0

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

pandas/core/frame.py

+20
Original file line numberDiff line numberDiff line change
@@ -4541,6 +4541,26 @@ def align(
45414541
broadcast_axis=broadcast_axis,
45424542
)
45434543

4544+
@overload
4545+
def set_axis(
4546+
self, labels, axis: Axis = ..., inplace: Literal[False] = ...
4547+
) -> DataFrame:
4548+
...
4549+
4550+
@overload
4551+
def set_axis(self, labels, axis: Axis, inplace: Literal[True]) -> None:
4552+
...
4553+
4554+
@overload
4555+
def set_axis(self, labels, *, inplace: Literal[True]) -> None:
4556+
...
4557+
4558+
@overload
4559+
def set_axis(
4560+
self, labels, axis: Axis = ..., inplace: bool = ...
4561+
) -> Optional[DataFrame]:
4562+
...
4563+
45444564
@Appender(
45454565
"""
45464566
Examples

pandas/core/generic.py

+25
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Type,
2525
Union,
2626
cast,
27+
overload,
2728
)
2829
import warnings
2930
import weakref
@@ -162,6 +163,8 @@
162163
from pandas.io.formats.printing import pprint_thing
163164

164165
if TYPE_CHECKING:
166+
from typing import Literal
167+
165168
from pandas._libs.tslibs import BaseOffset
166169

167170
from pandas.core.frame import DataFrame
@@ -682,6 +685,28 @@ def _obj_with_exclusions(self: FrameOrSeries) -> FrameOrSeries:
682685
""" internal compat with SelectionMixin """
683686
return self
684687

688+
@overload
689+
def set_axis(
690+
self: FrameOrSeries, labels, axis: Axis = ..., inplace: Literal[False] = ...
691+
) -> FrameOrSeries:
692+
...
693+
694+
@overload
695+
def set_axis(
696+
self: FrameOrSeries, labels, axis: Axis, inplace: Literal[True]
697+
) -> None:
698+
...
699+
700+
@overload
701+
def set_axis(self: FrameOrSeries, labels, *, inplace: Literal[True]) -> None:
702+
...
703+
704+
@overload
705+
def set_axis(
706+
self: FrameOrSeries, labels, axis: Axis = ..., inplace: bool = ...
707+
) -> Optional[FrameOrSeries]:
708+
...
709+
685710
def set_axis(self, labels, axis: Axis = 0, inplace: bool = False):
686711
"""
687712
Assign desired index to given axis.

pandas/core/series.py

+23
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Type,
2121
Union,
2222
cast,
23+
overload,
2324
)
2425
import warnings
2526

@@ -142,6 +143,8 @@
142143
import pandas.plotting
143144

144145
if TYPE_CHECKING:
146+
from typing import Literal
147+
145148
from pandas._typing import (
146149
TimedeltaConvertibleTypes,
147150
TimestampConvertibleTypes,
@@ -4342,6 +4345,26 @@ def rename(
43424345
else:
43434346
return self._set_name(index, inplace=inplace)
43444347

4348+
@overload
4349+
def set_axis(
4350+
self, labels, axis: Axis = ..., inplace: Literal[False] = ...
4351+
) -> Series:
4352+
...
4353+
4354+
@overload
4355+
def set_axis(self, labels, axis: Axis, inplace: Literal[True]) -> None:
4356+
...
4357+
4358+
@overload
4359+
def set_axis(self, labels, *, inplace: Literal[True]) -> None:
4360+
...
4361+
4362+
@overload
4363+
def set_axis(
4364+
self, labels, axis: Axis = ..., inplace: bool = ...
4365+
) -> Optional[Series]:
4366+
...
4367+
43454368
@Appender(
43464369
"""
43474370
Examples

0 commit comments

Comments
 (0)