diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 0766aa586..f80b65bd6 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -44,6 +44,7 @@ from pandas._typing import ( AggFuncType, AggFuncTypeBase, AggFuncTypeDict, + AnyArrayLike, ArrayLike, Axes, Axis, @@ -1049,9 +1050,9 @@ class DataFrame(NDFrame, OpsMixin): self, right: DataFrame | Series, how: MergeHow = ..., - on: IndexLabel | None = ..., - left_on: IndexLabel | None = ..., - right_on: IndexLabel | None = ..., + on: IndexLabel | AnyArrayLike | None = ..., + left_on: IndexLabel | AnyArrayLike | None = ..., + right_on: IndexLabel | AnyArrayLike | None = ..., left_index: _bool = ..., right_index: _bool = ..., sort: _bool = ..., diff --git a/pandas-stubs/core/reshape/merge.pyi b/pandas-stubs/core/reshape/merge.pyi index b9f441670..8233f51c2 100644 --- a/pandas-stubs/core/reshape/merge.pyi +++ b/pandas-stubs/core/reshape/merge.pyi @@ -6,15 +6,18 @@ from pandas import ( ) from pandas._libs.tslibs import Timedelta -from pandas._typing import Label +from pandas._typing import ( + AnyArrayLike, + Label, +) def merge( left: DataFrame | Series, right: DataFrame | Series, how: str = ..., - on: Label | Sequence | None = ..., - left_on: Label | Sequence | None = ..., - right_on: Label | Sequence | None = ..., + on: Label | Sequence | AnyArrayLike | None = ..., + left_on: Label | Sequence | AnyArrayLike | None = ..., + right_on: Label | Sequence | AnyArrayLike | None = ..., left_index: bool = ..., right_index: bool = ..., sort: bool = ..., @@ -26,9 +29,9 @@ def merge( def merge_ordered( left: DataFrame | Series, right: DataFrame | Series, - on: Label | Sequence | None = ..., - left_on: Label | Sequence | None = ..., - right_on: Label | Sequence | None = ..., + on: Label | Sequence | AnyArrayLike | None = ..., + left_on: Label | Sequence | AnyArrayLike | None = ..., + right_on: Label | Sequence | AnyArrayLike | None = ..., left_by: str | list[str] | None = ..., right_by: str | list[str] | None = ..., fill_method: str | None = ..., @@ -39,8 +42,8 @@ def merge_asof( left: DataFrame | Series, right: DataFrame | Series, on: Label | None = ..., - left_on: Label | None = ..., - right_on: Label | None = ..., + left_on: Label | AnyArrayLike | None = ..., + right_on: Label | AnyArrayLike | None = ..., left_index: bool = ..., right_index: bool = ..., by: str | list[str] | None = ..., diff --git a/tests/test_merge.py b/tests/test_merge.py index 9621ccb30..92ce63acb 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -1,6 +1,10 @@ from __future__ import annotations +import numpy as np import pandas as pd +from typing_extensions import assert_type + +from tests import check def test_types_merge() -> None: @@ -8,3 +12,126 @@ def test_types_merge() -> None: df2 = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [0, 1, 0]}) columns = ["col1", "col2"] df.merge(df2, on=columns) + + check( + assert_type(df.merge(df2, on=pd.Series([1, 2, 3])), pd.DataFrame), pd.DataFrame + ) + check( + assert_type(df.merge(df2, on=pd.Index([1, 2, 3])), pd.DataFrame), pd.DataFrame + ) + check( + assert_type(df.merge(df2, on=np.array([1, 2, 3])), pd.DataFrame), pd.DataFrame + ) + + check( + assert_type( + df.merge(df2, left_on=pd.Series([1, 2, 3]), right_on=pd.Series([1, 2, 3])), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + df.merge(df2, left_on=pd.Index([1, 2, 3]), right_on=pd.Series([1, 2, 3])), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + df.merge(df2, left_on=pd.Index([1, 2, 3]), right_on=pd.Index([1, 2, 3])), + pd.DataFrame, + ), + pd.DataFrame, + ) + + check( + assert_type( + df.merge(df2, left_on=np.array([1, 2, 3]), right_on=pd.Series([1, 2, 3])), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + df.merge(df2, left_on=np.array([1, 2, 3]), right_on=pd.Index([1, 2, 3])), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + df.merge(df2, left_on=np.array([1, 2, 3]), right_on=np.array([1, 2, 3])), + pd.DataFrame, + ), + pd.DataFrame, + ) + + check( + assert_type(pd.merge(df, df2, on=pd.Series([1, 2, 3])), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type(pd.merge(df, df2, on=pd.Index([1, 2, 3])), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type(pd.merge(df, df2, on=np.array([1, 2, 3])), pd.DataFrame), + pd.DataFrame, + ) + + check( + assert_type( + pd.merge( + df, df2, left_on=pd.Series([1, 2, 3]), right_on=pd.Series([1, 2, 3]) + ), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + pd.merge( + df, df2, left_on=pd.Index([1, 2, 3]), right_on=pd.Series([1, 2, 3]) + ), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + pd.merge( + df, df2, left_on=pd.Index([1, 2, 3]), right_on=pd.Index([1, 2, 3]) + ), + pd.DataFrame, + ), + pd.DataFrame, + ) + + check( + assert_type( + pd.merge( + df, df2, left_on=np.array([1, 2, 3]), right_on=pd.Series([1, 2, 3]) + ), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + pd.merge( + df, df2, left_on=np.array([1, 2, 3]), right_on=pd.Index([1, 2, 3]) + ), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + pd.merge( + df, df2, left_on=np.array([1, 2, 3]), right_on=np.array([1, 2, 3]) + ), + pd.DataFrame, + ), + pd.DataFrame, + )