Skip to content

Commit ed2ea7f

Browse files
amaliujiaHyukjinKwon
authored andcommitted
[SPARK-40915][CONNECT] Improve on in Join in Python client
### What changes were proposed in this pull request? 1. Fix Join's `on` from ANY to concrete types (e.g. str, list of str, column, list of columns) 2. When `on` is str or list of str, it should generate a proto plan with `using_columns` 3. When `on` is column or list of column, it should generate a proto plan with `join_condition`. ### Why are the changes needed? Improve API coverage ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #38393 from amaliujia/python_join_on_improvement. Authored-by: Rui Wang <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent b3ed0c1 commit ed2ea7f

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

python/pyspark/sql/connect/dataframe.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,13 @@ def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
217217
def head(self, n: int) -> Optional["pandas.DataFrame"]:
218218
return self.limit(n).toPandas()
219219

220-
# TODO(martin.grund) fix mypu
221-
def join(self, other: "DataFrame", on: Any, how: Optional[str] = None) -> "DataFrame":
220+
# TODO: extend `on` to also be type List[ColumnRef].
221+
def join(
222+
self,
223+
other: "DataFrame",
224+
on: Optional[Union[str, List[str], ColumnRef]] = None,
225+
how: Optional[str] = None,
226+
) -> "DataFrame":
222227
if self._plan is None:
223228
raise Exception("Cannot join when self._plan is empty.")
224229
if other._plan is None:

python/pyspark/sql/connect/plan.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def __init__(
537537
self,
538538
left: Optional["LogicalPlan"],
539539
right: "LogicalPlan",
540-
on: "ColumnOrString",
540+
on: Optional[Union[str, List[str], ColumnRef]],
541541
how: Optional[str],
542542
) -> None:
543543
super().__init__(left)
@@ -575,7 +575,14 @@ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
575575
rel = proto.Relation()
576576
rel.join.left.CopyFrom(self.left.plan(session))
577577
rel.join.right.CopyFrom(self.right.plan(session))
578-
rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
578+
if self.on is not None:
579+
if not isinstance(self.on, list):
580+
if isinstance(self.on, str):
581+
rel.join.using_columns.append(self.on)
582+
else:
583+
rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
584+
else:
585+
rel.join.using_columns.extend(self.on)
579586
rel.join.join_type = self.how
580587
return rel
581588

python/pyspark/sql/tests/connect/test_connect_plan_only.py

+17
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,23 @@ def test_simple_project(self):
3737
self.assertIsNotNone(plan.root, "Root relation must be set")
3838
self.assertIsNotNone(plan.root.read)
3939

40+
def test_join_using_columns(self):
41+
left_input = self.connect.readTable(table_name=self.tbl_name)
42+
right_input = self.connect.readTable(table_name=self.tbl_name)
43+
plan = left_input.join(other=right_input, on="join_column")._plan.to_proto(self.connect)
44+
self.assertEqual(len(plan.root.join.using_columns), 1)
45+
46+
plan2 = left_input.join(other=right_input, on=["col1", "col2"])._plan.to_proto(self.connect)
47+
self.assertEqual(len(plan2.root.join.using_columns), 2)
48+
49+
def test_join_condition(self):
50+
left_input = self.connect.readTable(table_name=self.tbl_name)
51+
right_input = self.connect.readTable(table_name=self.tbl_name)
52+
plan = left_input.join(
53+
other=right_input, on=left_input.name == right_input.name
54+
)._plan.to_proto(self.connect)
55+
self.assertIsNotNone(plan.root.join.join_condition)
56+
4057
def test_filter(self):
4158
df = self.connect.readTable(table_name=self.tbl_name)
4259
plan = df.filter(df.col_name > 3)._plan.to_proto(self.connect)

0 commit comments

Comments
 (0)