Skip to content

Commit 522b3b6

Browse files
Fix result type of DataFrame.align (#577)
* Fix result type of `DataFrame.align` * Use `NDFrameT` instead of newly added `_SeriesOrFrameT` * Add tests for `DataFrame.align`
1 parent 41fe0f5 commit 522b3b6

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

pandas-stubs/core/frame.pyi

+3-2
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ from pandas._typing import (
9292
MaskType,
9393
MergeHow,
9494
NaPosition,
95+
NDFrameT,
9596
ParquetEngine,
9697
QuantileInterpolation,
9798
RandomState,
@@ -548,7 +549,7 @@ class DataFrame(NDFrame, OpsMixin):
548549
def lookup(self, row_labels: Sequence, col_labels: Sequence) -> np.ndarray: ...
549550
def align(
550551
self,
551-
other: DataFrame | Series,
552+
other: NDFrameT,
552553
join: JoinHow = ...,
553554
axis: Axis | None = ...,
554555
level: Level | None = ...,
@@ -558,7 +559,7 @@ class DataFrame(NDFrame, OpsMixin):
558559
limit: int | None = ...,
559560
fill_axis: Axis = ...,
560561
broadcast_axis: Axis | None = ...,
561-
) -> DataFrame: ...
562+
) -> tuple[DataFrame, NDFrameT]: ...
562563
def reindex(
563564
self,
564565
labels: Axes | None = ...,

tests/test_frame.py

+38
Original file line numberDiff line numberDiff line change
@@ -2483,3 +2483,41 @@ def test_xs_frame_new() -> None:
24832483
s2 = df.xs("num_wings", axis=1)
24842484
check(assert_type(s1, Union[pd.Series, pd.DataFrame]), pd.DataFrame)
24852485
check(assert_type(s2, Union[pd.Series, pd.DataFrame]), pd.Series)
2486+
2487+
2488+
def test_align() -> None:
2489+
df0 = pd.DataFrame(
2490+
data=np.array(
2491+
[
2492+
["A0", "A1", "A2", "A3"],
2493+
["B0", "B1", "B2", "B3"],
2494+
["C0", "C1", "C2", "C3"],
2495+
]
2496+
).T,
2497+
index=[0, 1, 2, 3],
2498+
columns=["A", "B", "C"],
2499+
)
2500+
2501+
s0 = pd.Series(data={1: "1", 3: "3", 5: "5"})
2502+
aligned_df0, aligned_s0 = df0.align(s0, axis="index")
2503+
check(assert_type(aligned_df0, pd.DataFrame), pd.DataFrame)
2504+
check(assert_type(aligned_s0, pd.Series), pd.Series)
2505+
2506+
s1 = pd.Series(data={"A": "A", "D": "D"})
2507+
aligned_df0, aligned_s1 = df0.align(s1, axis="columns")
2508+
check(assert_type(aligned_df0, pd.DataFrame), pd.DataFrame)
2509+
check(assert_type(aligned_s1, pd.Series), pd.Series)
2510+
2511+
df1 = pd.DataFrame(
2512+
data=np.array(
2513+
[
2514+
["A1", "A3", "A5"],
2515+
["D1", "D3", "D5"],
2516+
]
2517+
).T,
2518+
index=[1, 3, 5],
2519+
columns=["A", "D"],
2520+
)
2521+
aligned_df0, aligned_df1 = df0.align(df1)
2522+
check(assert_type(aligned_df0, pd.DataFrame), pd.DataFrame)
2523+
check(assert_type(aligned_df1, pd.DataFrame), pd.DataFrame)

0 commit comments

Comments
 (0)