Skip to content

Commit 713008d

Browse files
authored
Add AnyArrayLike for merge on arguments (#231)
* Add AnyArrayLike for merge on arguments * Add checks
1 parent a02385e commit 713008d

File tree

3 files changed

+143
-12
lines changed

3 files changed

+143
-12
lines changed

pandas-stubs/core/frame.pyi

+4-3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ from pandas._typing import (
4444
AggFuncType,
4545
AggFuncTypeBase,
4646
AggFuncTypeDict,
47+
AnyArrayLike,
4748
ArrayLike,
4849
Axes,
4950
Axis,
@@ -1049,9 +1050,9 @@ class DataFrame(NDFrame, OpsMixin):
10491050
self,
10501051
right: DataFrame | Series,
10511052
how: MergeHow = ...,
1052-
on: IndexLabel | None = ...,
1053-
left_on: IndexLabel | None = ...,
1054-
right_on: IndexLabel | None = ...,
1053+
on: IndexLabel | AnyArrayLike | None = ...,
1054+
left_on: IndexLabel | AnyArrayLike | None = ...,
1055+
right_on: IndexLabel | AnyArrayLike | None = ...,
10551056
left_index: _bool = ...,
10561057
right_index: _bool = ...,
10571058
sort: _bool = ...,

pandas-stubs/core/reshape/merge.pyi

+12-9
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@ from pandas import (
66
)
77

88
from pandas._libs.tslibs import Timedelta
9-
from pandas._typing import Label
9+
from pandas._typing import (
10+
AnyArrayLike,
11+
Label,
12+
)
1013

1114
def merge(
1215
left: DataFrame | Series,
1316
right: DataFrame | Series,
1417
how: str = ...,
15-
on: Label | Sequence | None = ...,
16-
left_on: Label | Sequence | None = ...,
17-
right_on: Label | Sequence | None = ...,
18+
on: Label | Sequence | AnyArrayLike | None = ...,
19+
left_on: Label | Sequence | AnyArrayLike | None = ...,
20+
right_on: Label | Sequence | AnyArrayLike | None = ...,
1821
left_index: bool = ...,
1922
right_index: bool = ...,
2023
sort: bool = ...,
@@ -26,9 +29,9 @@ def merge(
2629
def merge_ordered(
2730
left: DataFrame | Series,
2831
right: DataFrame | Series,
29-
on: Label | Sequence | None = ...,
30-
left_on: Label | Sequence | None = ...,
31-
right_on: Label | Sequence | None = ...,
32+
on: Label | Sequence | AnyArrayLike | None = ...,
33+
left_on: Label | Sequence | AnyArrayLike | None = ...,
34+
right_on: Label | Sequence | AnyArrayLike | None = ...,
3235
left_by: str | list[str] | None = ...,
3336
right_by: str | list[str] | None = ...,
3437
fill_method: str | None = ...,
@@ -39,8 +42,8 @@ def merge_asof(
3942
left: DataFrame | Series,
4043
right: DataFrame | Series,
4144
on: Label | None = ...,
42-
left_on: Label | None = ...,
43-
right_on: Label | None = ...,
45+
left_on: Label | AnyArrayLike | None = ...,
46+
right_on: Label | AnyArrayLike | None = ...,
4447
left_index: bool = ...,
4548
right_index: bool = ...,
4649
by: str | list[str] | None = ...,

tests/test_merge.py

+127
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,137 @@
11
from __future__ import annotations
22

3+
import numpy as np
34
import pandas as pd
5+
from typing_extensions import assert_type
6+
7+
from tests import check
48

59

610
def test_types_merge() -> None:
711
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5]})
812
df2 = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [0, 1, 0]})
913
columns = ["col1", "col2"]
1014
df.merge(df2, on=columns)
15+
16+
check(
17+
assert_type(df.merge(df2, on=pd.Series([1, 2, 3])), pd.DataFrame), pd.DataFrame
18+
)
19+
check(
20+
assert_type(df.merge(df2, on=pd.Index([1, 2, 3])), pd.DataFrame), pd.DataFrame
21+
)
22+
check(
23+
assert_type(df.merge(df2, on=np.array([1, 2, 3])), pd.DataFrame), pd.DataFrame
24+
)
25+
26+
check(
27+
assert_type(
28+
df.merge(df2, left_on=pd.Series([1, 2, 3]), right_on=pd.Series([1, 2, 3])),
29+
pd.DataFrame,
30+
),
31+
pd.DataFrame,
32+
)
33+
check(
34+
assert_type(
35+
df.merge(df2, left_on=pd.Index([1, 2, 3]), right_on=pd.Series([1, 2, 3])),
36+
pd.DataFrame,
37+
),
38+
pd.DataFrame,
39+
)
40+
check(
41+
assert_type(
42+
df.merge(df2, left_on=pd.Index([1, 2, 3]), right_on=pd.Index([1, 2, 3])),
43+
pd.DataFrame,
44+
),
45+
pd.DataFrame,
46+
)
47+
48+
check(
49+
assert_type(
50+
df.merge(df2, left_on=np.array([1, 2, 3]), right_on=pd.Series([1, 2, 3])),
51+
pd.DataFrame,
52+
),
53+
pd.DataFrame,
54+
)
55+
check(
56+
assert_type(
57+
df.merge(df2, left_on=np.array([1, 2, 3]), right_on=pd.Index([1, 2, 3])),
58+
pd.DataFrame,
59+
),
60+
pd.DataFrame,
61+
)
62+
check(
63+
assert_type(
64+
df.merge(df2, left_on=np.array([1, 2, 3]), right_on=np.array([1, 2, 3])),
65+
pd.DataFrame,
66+
),
67+
pd.DataFrame,
68+
)
69+
70+
check(
71+
assert_type(pd.merge(df, df2, on=pd.Series([1, 2, 3])), pd.DataFrame),
72+
pd.DataFrame,
73+
)
74+
check(
75+
assert_type(pd.merge(df, df2, on=pd.Index([1, 2, 3])), pd.DataFrame),
76+
pd.DataFrame,
77+
)
78+
check(
79+
assert_type(pd.merge(df, df2, on=np.array([1, 2, 3])), pd.DataFrame),
80+
pd.DataFrame,
81+
)
82+
83+
check(
84+
assert_type(
85+
pd.merge(
86+
df, df2, left_on=pd.Series([1, 2, 3]), right_on=pd.Series([1, 2, 3])
87+
),
88+
pd.DataFrame,
89+
),
90+
pd.DataFrame,
91+
)
92+
check(
93+
assert_type(
94+
pd.merge(
95+
df, df2, left_on=pd.Index([1, 2, 3]), right_on=pd.Series([1, 2, 3])
96+
),
97+
pd.DataFrame,
98+
),
99+
pd.DataFrame,
100+
)
101+
check(
102+
assert_type(
103+
pd.merge(
104+
df, df2, left_on=pd.Index([1, 2, 3]), right_on=pd.Index([1, 2, 3])
105+
),
106+
pd.DataFrame,
107+
),
108+
pd.DataFrame,
109+
)
110+
111+
check(
112+
assert_type(
113+
pd.merge(
114+
df, df2, left_on=np.array([1, 2, 3]), right_on=pd.Series([1, 2, 3])
115+
),
116+
pd.DataFrame,
117+
),
118+
pd.DataFrame,
119+
)
120+
check(
121+
assert_type(
122+
pd.merge(
123+
df, df2, left_on=np.array([1, 2, 3]), right_on=pd.Index([1, 2, 3])
124+
),
125+
pd.DataFrame,
126+
),
127+
pd.DataFrame,
128+
)
129+
check(
130+
assert_type(
131+
pd.merge(
132+
df, df2, left_on=np.array([1, 2, 3]), right_on=np.array([1, 2, 3])
133+
),
134+
pd.DataFrame,
135+
),
136+
pd.DataFrame,
137+
)

0 commit comments

Comments
 (0)