Skip to content

Commit dee2bac

Browse files
authored
TYP: fix NDFrame.align (#51867)
* TYP: fix NDFrame.align * avoid ignores outside NDFrame.align (but add them inside NDFrame.align) * added comment * move to pandas._typing * undo pyright change
1 parent 4d74fbd commit dee2bac

File tree

4 files changed

+34
-23
lines changed

4 files changed

+34
-23
lines changed

pandas/_typing.py

+3
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@
125125
# Series is passed into a function, a Series is always returned and if a DataFrame is
126126
# passed in, a DataFrame is always returned.
127127
NDFrameT = TypeVar("NDFrameT", bound="NDFrame")
128+
# same as NDFrameT, needed when binding two pairs of parameters to potentially
129+
# separate NDFrame-subclasses (see NDFrame.align)
130+
NDFrameTb = TypeVar("NDFrameTb", bound="NDFrame")
128131

129132
NumpyIndexT = TypeVar("NumpyIndexT", np.ndarray, "Index")
130133

pandas/core/frame.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@
225225
Level,
226226
MergeHow,
227227
NaPosition,
228+
NDFrameT,
228229
PythonFuncType,
229230
QuantileInterpolation,
230231
ReadBuffer,
@@ -4997,7 +4998,7 @@ def _reindex_multi(
49974998
@doc(NDFrame.align, **_shared_doc_kwargs)
49984999
def align(
49995000
self,
5000-
other: DataFrame,
5001+
other: NDFrameT,
50015002
join: AlignJoin = "outer",
50025003
axis: Axis | None = None,
50035004
level: Level = None,
@@ -5007,7 +5008,7 @@ def align(
50075008
limit: int | None = None,
50085009
fill_axis: Axis = 0,
50095010
broadcast_axis: Axis | None = None,
5010-
) -> DataFrame:
5011+
) -> tuple[DataFrame, NDFrameT]:
50115012
return super().align(
50125013
other,
50135014
join=join,
@@ -7771,9 +7772,7 @@ def to_series(right):
77717772
)
77727773

77737774
left, right = left.align(
7774-
# error: Argument 1 to "align" of "DataFrame" has incompatible
7775-
# type "Series"; expected "DataFrame"
7776-
right, # type: ignore[arg-type]
7775+
right,
77777776
join="outer",
77787777
axis=axis,
77797778
level=level,

pandas/core/generic.py

+24-16
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
Manager,
6868
NaPosition,
6969
NDFrameT,
70+
NDFrameTb,
7071
RandomState,
7172
Renamer,
7273
Scalar,
@@ -198,7 +199,6 @@
198199
from pandas.core.indexers.objects import BaseIndexer
199200
from pandas.core.resample import Resampler
200201

201-
202202
# goal is to be able to define the docs close to function, while still being
203203
# able to share
204204
_shared_docs = {**_shared_docs}
@@ -9297,7 +9297,7 @@ def compare(
92979297
@doc(**_shared_doc_kwargs)
92989298
def align(
92999299
self: NDFrameT,
9300-
other: NDFrameT,
9300+
other: NDFrameTb,
93019301
join: AlignJoin = "outer",
93029302
axis: Axis | None = None,
93039303
level: Level = None,
@@ -9307,7 +9307,7 @@ def align(
93079307
limit: int | None = None,
93089308
fill_axis: Axis = 0,
93099309
broadcast_axis: Axis | None = None,
9310-
) -> NDFrameT:
9310+
) -> tuple[NDFrameT, NDFrameTb]:
93119311
"""
93129312
Align two objects on their axes with the specified join method.
93139313
@@ -9428,8 +9428,10 @@ def align(
94289428
df = cons(
94299429
{c: self for c in other.columns}, **other._construct_axes_dict()
94309430
)
9431-
return df._align_frame(
9432-
other,
9431+
# error: Incompatible return value type (got "Tuple[DataFrame,
9432+
# DataFrame]", expected "Tuple[NDFrameT, NDFrameTb]")
9433+
return df._align_frame( # type: ignore[return-value]
9434+
other, # type: ignore[arg-type]
94339435
join=join,
94349436
axis=axis,
94359437
level=level,
@@ -9446,7 +9448,9 @@ def align(
94469448
df = cons(
94479449
{c: other for c in self.columns}, **self._construct_axes_dict()
94489450
)
9449-
return self._align_frame(
9451+
# error: Incompatible return value type (got "Tuple[NDFrameT,
9452+
# DataFrame]", expected "Tuple[NDFrameT, NDFrameTb]")
9453+
return self._align_frame( # type: ignore[return-value]
94509454
df,
94519455
join=join,
94529456
axis=axis,
@@ -9461,7 +9465,9 @@ def align(
94619465
if axis is not None:
94629466
axis = self._get_axis_number(axis)
94639467
if isinstance(other, ABCDataFrame):
9464-
return self._align_frame(
9468+
# error: Incompatible return value type (got "Tuple[NDFrameT, DataFrame]",
9469+
# expected "Tuple[NDFrameT, NDFrameTb]")
9470+
return self._align_frame( # type: ignore[return-value]
94659471
other,
94669472
join=join,
94679473
axis=axis,
@@ -9473,7 +9479,9 @@ def align(
94739479
fill_axis=fill_axis,
94749480
)
94759481
elif isinstance(other, ABCSeries):
9476-
return self._align_series(
9482+
# error: Incompatible return value type (got "Tuple[NDFrameT, Series]",
9483+
# expected "Tuple[NDFrameT, NDFrameTb]")
9484+
return self._align_series( # type: ignore[return-value]
94779485
other,
94789486
join=join,
94799487
axis=axis,
@@ -9489,8 +9497,8 @@ def align(
94899497

94909498
@final
94919499
def _align_frame(
9492-
self,
9493-
other,
9500+
self: NDFrameT,
9501+
other: DataFrame,
94949502
join: AlignJoin = "outer",
94959503
axis: Axis | None = None,
94969504
level=None,
@@ -9499,7 +9507,7 @@ def _align_frame(
94999507
method=None,
95009508
limit=None,
95019509
fill_axis: Axis = 0,
9502-
):
9510+
) -> tuple[NDFrameT, DataFrame]:
95039511
# defaults
95049512
join_index, join_columns = None, None
95059513
ilidx, iridx = None, None
@@ -9553,8 +9561,8 @@ def _align_frame(
95539561

95549562
@final
95559563
def _align_series(
9556-
self,
9557-
other,
9564+
self: NDFrameT,
9565+
other: Series,
95589566
join: AlignJoin = "outer",
95599567
axis: Axis | None = None,
95609568
level=None,
@@ -9563,7 +9571,7 @@ def _align_series(
95639571
method=None,
95649572
limit=None,
95659573
fill_axis: Axis = 0,
9566-
):
9574+
) -> tuple[NDFrameT, Series]:
95679575
is_series = isinstance(self, ABCSeries)
95689576
if copy and using_copy_on_write():
95699577
copy = False
@@ -12798,8 +12806,8 @@ def _doc_params(cls):
1279812806

1279912807

1280012808
def _align_as_utc(
12801-
left: NDFrameT, right: NDFrameT, join_index: Index | None
12802-
) -> tuple[NDFrameT, NDFrameT]:
12809+
left: NDFrameT, right: NDFrameTb, join_index: Index | None
12810+
) -> tuple[NDFrameT, NDFrameTb]:
1280312811
"""
1280412812
If we are aligning timezone-aware DatetimeIndexes and the timezones
1280512813
do not match, convert both to UTC.

pandas/core/series.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@
166166
IndexLabel,
167167
Level,
168168
NaPosition,
169+
NDFrameT,
169170
NumpySorter,
170171
NumpyValueArrayLike,
171172
QuantileInterpolation,
@@ -4571,7 +4572,7 @@ def _needs_reindex_multi(self, axes, method, level) -> bool:
45714572
)
45724573
def align(
45734574
self,
4574-
other: Series,
4575+
other: NDFrameT,
45754576
join: AlignJoin = "outer",
45764577
axis: Axis | None = None,
45774578
level: Level = None,
@@ -4581,7 +4582,7 @@ def align(
45814582
limit: int | None = None,
45824583
fill_axis: Axis = 0,
45834584
broadcast_axis: Axis | None = None,
4584-
) -> Series:
4585+
) -> tuple[Series, NDFrameT]:
45854586
return super().align(
45864587
other,
45874588
join=join,

0 commit comments

Comments
 (0)