Skip to content

Commit c09c779

Browse files
amaliujiacloud-fan
authored andcommitted
[SPARK-40938][CONNECT] Support Alias for every type of Relation
### What changes were proposed in this pull request?
Previously, the Connect server only checked `alias` for `Read` and `Project` relations. However, for a Spark DataFrame, every DataFrame can be chained with `as(alias: String)`, so every Relation/LogicalPlan can have an `alias`. This PR refactors the alias handling (introducing a dedicated `SubqueryAlias` relation) to make this work for every type of relation.
### Why are the changes needed?
Improve API coverage.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests.
Closes #38415 from amaliujia/every_relation_has_alias. Authored-by: Rui Wang <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent ed2ea7f commit c09c779

File tree

8 files changed

+173
-105
lines changed

8 files changed

+173
-105
lines changed

connector/connect/src/main/protobuf/spark/connect/relations.proto

+12-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ message Relation {
4545
Offset offset = 13;
4646
Deduplicate deduplicate = 14;
4747
Range range = 15;
48+
SubqueryAlias subquery_alias = 16;
4849

4950
Unknown unknown = 999;
5051
}
@@ -56,7 +57,6 @@ message Unknown {}
5657
// Common metadata of all relations.
5758
message RelationCommon {
5859
string source_info = 1;
59-
string alias = 2;
6060
}
6161

6262
// Relation that uses a SQL query to generate the output.
@@ -223,6 +223,7 @@ message Sample {
223223
message Range {
224224
// Optional. Default value = 0
225225
int32 start = 1;
226+
// Required.
226227
int32 end = 2;
227228
// Optional. Default value = 1
228229
Step step = 3;
@@ -238,3 +239,13 @@ message Range {
238239
int32 num_partitions = 1;
239240
}
240241
}
242+
243+
// Relation alias.
244+
message SubqueryAlias {
245+
// Required. The input relation.
246+
Relation input = 1;
247+
// Required. The alias.
248+
string alias = 2;
249+
// Optional. Qualifier of the alias.
250+
repeated string qualifier = 3;
251+
}

connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ package object dsl {
308308
def as(alias: String): Relation = {
309309
Relation
310310
.newBuilder(logicalPlan)
311-
.setCommon(RelationCommon.newBuilder().setAlias(alias))
311+
.setSubqueryAlias(SubqueryAlias.newBuilder().setAlias(alias).setInput(logicalPlan))
312312
.build()
313313
}
314314

connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

+22-30
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
2222

2323
import org.apache.spark.connect.proto
2424
import org.apache.spark.sql.SparkSession
25+
import org.apache.spark.sql.catalyst.AliasIdentifier
2526
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
2627
import org.apache.spark.sql.catalyst.expressions
2728
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression}
@@ -54,8 +55,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
5455
}
5556

5657
rel.getRelTypeCase match {
57-
case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead, common)
58-
case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject, common)
58+
case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
59+
case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject)
5960
case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter)
6061
case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit)
6162
case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
@@ -66,9 +67,11 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
6667
case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
6768
case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
6869
case proto.Relation.RelTypeCase.LOCAL_RELATION =>
69-
transformLocalRelation(rel.getLocalRelation, common)
70+
transformLocalRelation(rel.getLocalRelation)
7071
case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample)
7172
case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange)
73+
case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
74+
transformSubqueryAlias(rel.getSubqueryAlias)
7275
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
7376
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
7477
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -79,6 +82,16 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
7982
session.sessionState.sqlParser.parsePlan(sql.getQuery)
8083
}
8184

85+
/**
 * Translates a [[proto.SubqueryAlias]] into a Catalyst `SubqueryAlias` plan node.
 *
 * The alias name is taken from the message's `alias` field. When the message carries one or
 * more `qualifier` parts they are attached to the resulting [[AliasIdentifier]]; otherwise an
 * unqualified identifier is built. The message's `input` relation is transformed recursively
 * and becomes the child of the alias node.
 *
 * @param alias the proto message describing the alias and the relation it wraps
 * @return a `SubqueryAlias` logical plan whose child is the transformed input relation
 */
private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
  val qualifierParts = alias.getQualifierList.asScala.toSeq
  val identifier = qualifierParts match {
    // No qualifier supplied: build a plain, unqualified alias.
    case Seq() => AliasIdentifier(alias.getAlias)
    case parts => AliasIdentifier(alias.getAlias, parts)
  }
  SubqueryAlias(identifier, transformRelation(alias.getInput))
}
94+
8295
/**
8396
* All fields of [[proto.Sample]] are optional. However, given those are proto primitive types,
8497
* we cannot differentiate if the field is not or set when the field's value equals to the type
@@ -141,35 +154,21 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
141154
}
142155
}
143156

144-
private def transformLocalRelation(
145-
rel: proto.LocalRelation,
146-
common: Option[proto.RelationCommon]): LogicalPlan = {
157+
private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
147158
val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq
148-
val relation = new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
149-
if (common.nonEmpty && common.get.getAlias.nonEmpty) {
150-
logical.SubqueryAlias(identifier = common.get.getAlias, child = relation)
151-
} else {
152-
relation
153-
}
159+
new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
154160
}
155161

156162
private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = {
157163
AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))()
158164
}
159165

160-
private def transformReadRel(
161-
rel: proto.Read,
162-
common: Option[proto.RelationCommon]): LogicalPlan = {
166+
private def transformReadRel(rel: proto.Read): LogicalPlan = {
163167
val baseRelation = rel.getReadTypeCase match {
164168
case proto.Read.ReadTypeCase.NAMED_TABLE =>
165169
val multipartIdentifier =
166170
CatalystSqlParser.parseMultipartIdentifier(rel.getNamedTable.getUnparsedIdentifier)
167-
val child = UnresolvedRelation(multipartIdentifier)
168-
if (common.nonEmpty && common.get.getAlias.nonEmpty) {
169-
SubqueryAlias(identifier = common.get.getAlias, child = child)
170-
} else {
171-
child
172-
}
171+
UnresolvedRelation(multipartIdentifier)
173172
case proto.Read.ReadTypeCase.DATA_SOURCE =>
174173
if (rel.getDataSource.getFormat == "") {
175174
throw InvalidPlanInput("DataSource requires a format")
@@ -193,9 +192,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
193192
logical.Filter(condition = transformExpression(rel.getCondition), child = baseRel)
194193
}
195194

196-
private def transformProject(
197-
rel: proto.Project,
198-
common: Option[proto.RelationCommon]): LogicalPlan = {
195+
private def transformProject(rel: proto.Project): LogicalPlan = {
199196
val baseRel = transformRelation(rel.getInput)
200197
// TODO: support the target field for *.
201198
val projection =
@@ -204,12 +201,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
204201
} else {
205202
rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
206203
}
207-
val project = logical.Project(projectList = projection.toSeq, child = baseRel)
208-
if (common.nonEmpty && common.get.getAlias.nonEmpty) {
209-
logical.SubqueryAlias(identifier = common.get.getAlias, child = project)
210-
} else {
211-
project
212-
}
204+
logical.Project(projectList = projection.toSeq, child = baseRel)
213205
}
214206

215207
private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = {

python/pyspark/sql/connect/dataframe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":
121121
return self.groupBy().agg(exprs)
122122

123123
def alias(self, alias: str) -> "DataFrame":
124-
return DataFrame.withPlan(plan.Project(self._plan).withAlias(alias), session=self._session)
124+
return DataFrame.withPlan(plan.SubqueryAlias(self._plan, alias), session=self._session)
125125

126126
def approxQuantile(self, col: ColumnRef, probabilities: Any, relativeError: Any) -> "DataFrame":
127127
...

python/pyspark/sql/connect/plan.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,6 @@ def _verify_expressions(self) -> None:
201201
f"Only Expressions or String can be used for projections: '{c}'."
202202
)
203203

204-
def withAlias(self, alias: str) -> LogicalPlan:
205-
self.alias = alias
206-
return self
207-
208204
def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
209205
assert self._child is not None
210206
proj_exprs = []
@@ -217,14 +213,10 @@ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
217213
proj_exprs.append(exp)
218214
else:
219215
proj_exprs.append(self.unresolved_attr(c))
220-
common = proto.RelationCommon()
221-
if self.alias is not None:
222-
common.alias = self.alias
223216

224217
plan = proto.Relation()
225218
plan.project.input.CopyFrom(self._child.plan(session))
226219
plan.project.expressions.extend(proj_exprs)
227-
plan.common.CopyFrom(common)
228220
return plan
229221

230222
def print(self, indent: int = 0) -> str:
@@ -648,6 +640,34 @@ def _repr_html_(self) -> str:
648640
"""
649641

650642

643+
class SubqueryAlias(LogicalPlan):
    """Logical plan node that attaches an alias to its child relation.

    Parameters
    ----------
    child : LogicalPlan, optional
        The plan being aliased. It must be set by the time :meth:`plan` is called.
    alias : str
        The alias name to assign to the child relation.
    """

    def __init__(self, child: Optional["LogicalPlan"], alias: str) -> None:
        super().__init__(child)
        self._alias = alias

    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
        """Build the ``subquery_alias`` proto relation for this node.

        Both the child plan (``input``) and the alias name are written to the message.
        """
        # The proto declares `input` as required; without it the server would be handed
        # an empty relation to alias.
        assert self._child is not None
        rel = proto.Relation()
        rel.subquery_alias.input.CopyFrom(self._child.plan(session))
        rel.subquery_alias.alias = self._alias
        return rel

    def print(self, indent: int = 0) -> str:
        """Return an indented textual rendering of this node followed by its child."""
        c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
        return f"{' ' * indent}<SubqueryAlias alias={self._alias}>\n{c_buf}"

    def _repr_html_(self) -> str:
        """Return an HTML rendering of this node listing its child and alias."""
        return f"""
        <ul>
           <li>
              <b>SubqueryAlias</b><br />
              Child: {self._child_repr_()}
              Alias: {self._alias}
           </li>
        </ul>
        """
669+
670+
651671
class SQL(LogicalPlan):
652672
def __init__(self, query: str) -> None:
653673
super().__init__(None)

0 commit comments

Comments
 (0)