diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 89d839f9e..d5f138e11 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -92,6 +92,7 @@ from pandas._typing import ( MaskType, MergeHow, NaPosition, + NDFrameT, ParquetEngine, QuantileInterpolation, RandomState, @@ -548,7 +549,7 @@ class DataFrame(NDFrame, OpsMixin): def lookup(self, row_labels: Sequence, col_labels: Sequence) -> np.ndarray: ... def align( self, - other: DataFrame | Series, + other: NDFrameT, join: JoinHow = ..., axis: Axis | None = ..., level: Level | None = ..., @@ -558,7 +559,7 @@ class DataFrame(NDFrame, OpsMixin): limit: int | None = ..., fill_axis: Axis = ..., broadcast_axis: Axis | None = ..., - ) -> DataFrame: ... + ) -> tuple[DataFrame, NDFrameT]: ... def reindex( self, labels: Axes | None = ..., diff --git a/tests/test_frame.py b/tests/test_frame.py index c0dc5c7f4..064f208bf 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2483,3 +2483,41 @@ def test_xs_frame_new() -> None: s2 = df.xs("num_wings", axis=1) check(assert_type(s1, Union[pd.Series, pd.DataFrame]), pd.DataFrame) check(assert_type(s2, Union[pd.Series, pd.DataFrame]), pd.Series) + + +def test_align() -> None: + df0 = pd.DataFrame( + data=np.array( + [ + ["A0", "A1", "A2", "A3"], + ["B0", "B1", "B2", "B3"], + ["C0", "C1", "C2", "C3"], + ] + ).T, + index=[0, 1, 2, 3], + columns=["A", "B", "C"], + ) + + s0 = pd.Series(data={1: "1", 3: "3", 5: "5"}) + aligned_df0, aligned_s0 = df0.align(s0, axis="index") + check(assert_type(aligned_df0, pd.DataFrame), pd.DataFrame) + check(assert_type(aligned_s0, pd.Series), pd.Series) + + s1 = pd.Series(data={"A": "A", "D": "D"}) + aligned_df0, aligned_s1 = df0.align(s1, axis="columns") + check(assert_type(aligned_df0, pd.DataFrame), pd.DataFrame) + check(assert_type(aligned_s1, pd.Series), pd.Series) + + df1 = pd.DataFrame( + data=np.array( + [ + ["A1", "A3", "A5"], + ["D1", "D3", "D5"], + ] + ).T, + index=[1, 3, 5], + columns=["A", "D"], + ) + aligned_df0, aligned_df1 = df0.align(df1) + check(assert_type(aligned_df0, pd.DataFrame), pd.DataFrame) + check(assert_type(aligned_df1, pd.DataFrame), pd.DataFrame)